diff --git a/.github/workflows/ci_lin.yml b/.github/workflows/ci_lin.yml
index 31353ade9..5b11f1c5f 100644
--- a/.github/workflows/ci_lin.yml
+++ b/.github/workflows/ci_lin.yml
@@ -31,6 +31,7 @@ jobs:
           - test_suite: features
           - test_suite: api_coverage
           - test_suite: behavior_tests
+          - test_suite: applications
     steps:
       - name: Process ENV
         id: process_env
diff --git a/applications/applications.xml b/applications/applications.xml
new file mode 100644
index 000000000..2b3379422
--- /dev/null
+++ b/applications/applications.xml
@@ -0,0 +1,8 @@
+
+
+
+ applications test
+
+
+
+
diff --git a/applications/config/TEMPLATE_llama.xml b/applications/config/TEMPLATE_llama.xml
new file mode 100644
index 000000000..d9be47442
--- /dev/null
+++ b/applications/config/TEMPLATE_llama.xml
@@ -0,0 +1,12 @@
+
+
+
+ test
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/applications/src/llama/LICENSE b/applications/src/llama/LICENSE
new file mode 100644
index 000000000..76f67efdc
--- /dev/null
+++ b/applications/src/llama/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Georgi Gerganov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/applications/src/llama/common/base64.hpp b/applications/src/llama/common/base64.hpp
new file mode 100644
index 000000000..563247a6e
--- /dev/null
+++ b/applications/src/llama/common/base64.hpp
@@ -0,0 +1,392 @@
+/*
+This is free and unencumbered software released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or
+distribute this software, either in source code form or as a compiled
+binary, for any purpose, commercial or non-commercial, and by any
+means.
+
+In jurisdictions that recognize copyright laws, the author or authors
+of this software dedicate any and all copyright interest in the
+software to the public domain. We make this dedication for the benefit
+of the public at large and to the detriment of our heirs and
+successors. We intend this dedication to be an overt act of
+relinquishment in perpetuity of all present and future rights to this
+software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
+ +For more information, please refer to +*/ + +#ifndef PUBLIC_DOMAIN_BASE64_HPP_ +#define PUBLIC_DOMAIN_BASE64_HPP_ + +#include +#include +#include +#include + +class base64_error : public std::runtime_error +{ +public: + using std::runtime_error::runtime_error; +}; + +class base64 +{ +public: + enum class alphabet + { + /** the alphabet is detected automatically */ + auto_, + /** the standard base64 alphabet is used */ + standard, + /** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/ + url_filename_safe + }; + + enum class decoding_behavior + { + /** if the input is not padded, the remaining bits are ignored */ + moderate, + /** if a padding character is encounter decoding is finished */ + loose + }; + + /** + Encodes all the elements from `in_begin` to `in_end` to `out`. + + @warning The source and destination cannot overlap. The destination must be able to hold at least + `required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator. + + @tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than + 8 bits + @tparam Output_iterator the destination; the elements written to it are from the type `char` + @param in_begin the beginning of the source + @param in_end the ending of the source + @param out the destination iterator + @param alphabet which alphabet should be used + @returns the iterator to the next element past the last element copied + @throws see `Input_iterator` and `Output_iterator` + */ + template + static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out, + alphabet alphabet = alphabet::standard) + { + constexpr auto pad = '='; + const char* alpha = alphabet == alphabet::url_filename_safe + ? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + : "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + while (in_begin != in_end) { + std::uint8_t i0 = 0, i1 = 0, i2 = 0; + + // first character + i0 = static_cast(*in_begin); + ++in_begin; + + *out = alpha[i0 >> 2 & 0x3f]; + ++out; + + // part of first character and second + if (in_begin != in_end) { + i1 = static_cast(*in_begin); + ++in_begin; + + *out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)]; + ++out; + } else { + *out = alpha[(i0 & 0x3) << 4]; + ++out; + + // last padding + *out = pad; + ++out; + + // last padding + *out = pad; + ++out; + + break; + } + + // part of second character and third + if (in_begin != in_end) { + i2 = static_cast(*in_begin); + ++in_begin; + + *out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)]; + ++out; + } else { + *out = alpha[(i1 & 0xf) << 2]; + ++out; + + // last padding + *out = pad; + ++out; + + break; + } + + // rest of third + *out = alpha[i2 & 0x3f]; + ++out; + } + + return out; + } + /** + Encodes a string. + + @param str the string that should be encoded + @param alphabet which alphabet should be used + @returns the encoded base64 string + @throws see base64::encode() + */ + static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard) + { + std::string result; + + result.reserve(required_encode_size(str.length()) + 1); + + encode(str.begin(), str.end(), std::back_inserter(result), alphabet); + + return result; + } + /** + Encodes a char array. 
+ + @param buffer the char array + @param size the size of the array + @param alphabet which alphabet should be used + @returns the encoded string + */ + static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard) + { + std::string result; + + result.reserve(required_encode_size(size) + 1); + + encode(buffer, buffer + size, std::back_inserter(result), alphabet); + + return result; + } + /** + Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`, + in other words: inplace decoding is possible. + + @warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`, + otherwise the behavior depends on the output iterator. + + @tparam Input_iterator the source; the returned elements are cast to `char` + @tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t` + @param in_begin the beginning of the source + @param in_end the ending of the source + @param out the destination iterator + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the iterator to the next element past the last element copied + @throws base64_error depending on the set behavior + @throws see `Input_iterator` and `Output_iterator` + */ + template + static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out, + alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + //constexpr auto pad = '='; + std::uint8_t last = 0; + auto bits = 0; + + while (in_begin != in_end) { + auto c = *in_begin; + ++in_begin; + + if (c == '=') { + break; + } + + auto part = _base64_value(alphabet, c); + + // enough bits for one byte + if (bits + 6 >= 8) { + *out = (last << (8 - bits)) | (part >> (bits - 2)); + ++out; + + bits -= 2; + } else { + bits += 6; + } + + last = part; + } + + // check padding + if (behavior != decoding_behavior::loose) { + while (in_begin != in_end) { + auto c = *in_begin; + ++in_begin; + + if (c != '=') { + throw base64_error("invalid base64 character."); + } + } + } + + return out; + } + /** + Decodes a string. + + @param str the base64 encoded string + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the decoded string + @throws see base64::decode() + */ + static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + std::string result; + + result.reserve(max_decode_size(str.length())); + + decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior); + + return result; + } + /** + Decodes a string. + + @param buffer the base64 encoded buffer + @param size the size of the buffer + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the decoded string + @throws see base64::decode() + */ + static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + std::string result; + + result.reserve(max_decode_size(size)); + + decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior); + + return result; + } + /** + Decodes a string inplace. 
+ + @param[in,out] str the base64 encoded string + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @throws base64::decode_inplace() + */ + static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin()); + } + /** + Decodes a char array inplace. + + @param[in,out] str the string array + @param size the length of the array + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the pointer to the next element past the last element decoded + @throws base64::decode_inplace() + */ + static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + return decode(str, str + size, str, alphabet, behavior); + } + /** + Returns the required decoding size for a given size. The value is calculated with the following formula: + + $$ + \lceil \frac{size}{4} \rceil \cdot 3 + $$ + + @param size the size of the encoded input + @returns the size of the resulting decoded buffer; this the absolute maximum + */ + static std::size_t max_decode_size(std::size_t size) noexcept + { + return (size / 4 + (size % 4 ? 1 : 0)) * 3; + } + /** + Returns the required encoding size for a given size. The value is calculated with the following formula: + + $$ + \lceil \frac{size}{3} \rceil \cdot 4 + $$ + + @param size the size of the decoded input + @returns the size of the resulting encoded buffer + */ + static std::size_t required_encode_size(std::size_t size) noexcept + { + return (size / 3 + (size % 3 ? 
1 : 0)) * 4; + } + +private: + static std::uint8_t _base64_value(alphabet& alphabet, char c) + { + if (c >= 'A' && c <= 'Z') { + return c - 'A'; + } else if (c >= 'a' && c <= 'z') { + return c - 'a' + 26; + } else if (c >= '0' && c <= '9') { + return c - '0' + 52; + } + + // comes down to alphabet + if (alphabet == alphabet::standard) { + if (c == '+') { + return 62; + } else if (c == '/') { + return 63; + } + } else if (alphabet == alphabet::url_filename_safe) { + if (c == '-') { + return 62; + } else if (c == '_') { + return 63; + } + } // auto detect + else { + if (c == '+') { + alphabet = alphabet::standard; + + return 62; + } else if (c == '/') { + alphabet = alphabet::standard; + + return 63; + } else if (c == '-') { + alphabet = alphabet::url_filename_safe; + + return 62; + } else if (c == '_') { + alphabet = alphabet::url_filename_safe; + + return 63; + } + } + + throw base64_error("invalid base64 character."); + } +}; + +#endif // !PUBLIC_DOMAIN_BASE64_HPP_ diff --git a/applications/src/llama/common/build-info.cpp b/applications/src/llama/common/build-info.cpp new file mode 100644 index 000000000..b1632a1e3 --- /dev/null +++ b/applications/src/llama/common/build-info.cpp @@ -0,0 +1,4 @@ +int LLAMA_BUILD_NUMBER = 1657; +char const *LLAMA_COMMIT = "3c04bf6"; +char const *LLAMA_COMPILER = "cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0"; +char const *LLAMA_BUILD_TARGET = "x86_64-linux-gnu"; diff --git a/applications/src/llama/common/build-info.cpp.in b/applications/src/llama/common/build-info.cpp.in new file mode 100644 index 000000000..0b945aa68 --- /dev/null +++ b/applications/src/llama/common/build-info.cpp.in @@ -0,0 +1,4 @@ +int LLAMA_BUILD_NUMBER = @BUILD_NUMBER@; +char const *LLAMA_COMMIT = "@BUILD_COMMIT@"; +char const *LLAMA_COMPILER = "@BUILD_COMPILER@"; +char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@"; diff --git a/applications/src/llama/common/common.cpp b/applications/src/llama/common/common.cpp new file mode 100644 index 000000000..93d5483e4 --- /dev/null +++ b/applications/src/llama/common/common.cpp @@ -0,0 +1,1616 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) && defined(__MACH__) +#include +#include +#endif + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#include +#include +#include +#include +#else +#include +#include +#include +#endif + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +int32_t get_num_physical_cores() { +#ifdef __linux__ + // enumerate the set of thread siblings, num entries is num cores + std::unordered_set siblings; + for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) { + std::ifstream thread_siblings("/sys/devices/system/cpu" + + std::to_string(cpu) + "/topology/thread_siblings"); + if (!thread_siblings.is_open()) { + break; // no more cpus + } + std::string line; + if (std::getline(thread_siblings, line)) { + siblings.insert(line); + } + } + if (!siblings.empty()) { + return static_cast(siblings.size()); + } +#elif defined(__APPLE__) && defined(__MACH__) + int32_t num_physical_cores; + size_t len = sizeof(num_physical_cores); + int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0); + if (result == 0) { + return num_physical_cores; + } + result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0); + if (result == 0) { + 
return num_physical_cores; + } +#elif defined(_WIN32) + //TODO: Implement +#endif + unsigned int n_threads = std::thread::hardware_concurrency(); + return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; +} + +void process_escapes(std::string& input) { + std::size_t input_len = input.length(); + std::size_t output_idx = 0; + + for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { + if (input[input_idx] == '\\' && input_idx + 1 < input_len) { + switch (input[++input_idx]) { + case 'n': input[output_idx++] = '\n'; break; + case 'r': input[output_idx++] = '\r'; break; + case 't': input[output_idx++] = '\t'; break; + case '\'': input[output_idx++] = '\''; break; + case '\"': input[output_idx++] = '\"'; break; + case '\\': input[output_idx++] = '\\'; break; + case 'x': + // Handle \x12, etc + if (input_idx + 2 < input_len) { + const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 }; + char *err_p = nullptr; + const long val = std::strtol(x, &err_p, 16); + if (err_p == x + 2) { + input_idx += 2; + input[output_idx++] = char(val); + break; + } + } + // fall through + default: input[output_idx++] = '\\'; + input[output_idx++] = input[input_idx]; break; + } + } else { + input[output_idx++] = input[input_idx]; + } + } + + input.resize(output_idx); +} + +bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { + bool result = true; + try { + if (!gpt_params_parse_ex(argc, argv, params)) { + gpt_print_usage(argc, argv, gpt_params()); + exit(0); + } + } + catch (const std::invalid_argument & ex) { + fprintf(stderr, "%s\n", ex.what()); + gpt_print_usage(argc, argv, gpt_params()); + exit(1); + } + return result; +} + +bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { + bool invalid_param = false; + std::string arg; + const std::string arg_prefix = "--"; + llama_sampling_params & sparams = params.sparams; + + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + + if (arg == "-s" || arg == "--seed") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.seed = std::stoul(argv[i]); + } else if (arg == "-t" || arg == "--threads") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads = std::stoi(argv[i]); + if (params.n_threads <= 0) { + params.n_threads = std::thread::hardware_concurrency(); + } + } else if (arg == "-tb" || arg == "--threads-batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_batch = std::stoi(argv[i]); + if (params.n_threads_batch <= 0) { + params.n_threads_batch = std::thread::hardware_concurrency(); + } + } else if (arg == "-p" || arg == "--prompt") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.prompt = argv[i]; + } else if (arg == "-e" || arg == "--escape") { + params.escape = true; + } else if (arg == "--prompt-cache") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.path_prompt_cache = argv[i]; + } else if (arg == "--prompt-cache-all") { + params.prompt_cache_all = true; + } else if (arg == "--prompt-cache-ro") { + params.prompt_cache_ro = true; + } else if (arg == "-f" || arg == "--file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + // store the external file name in params + params.prompt_file = argv[i]; + 
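// The -f/--file branch reads the whole prompt file into params.prompt with
// std::istreambuf_iterator<char> (the std::copy just below) and trims a single
// trailing newline. A minimal self-contained sketch of that idiom; the helper
// name read_file_to_string is illustrative only and not defined in this patch:
#include <fstream>
#include <iterator>
#include <string>

static std::string read_file_to_string(const std::string & path) {
    std::ifstream file(path);
    // copy characters from the stream until EOF into the returned string
    std::string text((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
    if (!text.empty() && text.back() == '\n') {
        text.pop_back(); // drop one trailing newline, as the option parser does
    }
    return text;
}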
std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); + if (!params.prompt.empty() && params.prompt.back() == '\n') { + params.prompt.pop_back(); + } + } else if (arg == "-n" || arg == "--n-predict") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_predict = std::stoi(argv[i]); + } else if (arg == "--top-k") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.top_k = std::stoi(argv[i]); + } else if (arg == "-c" || arg == "--ctx-size") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_ctx = std::stoi(argv[i]); + } else if (arg == "--rope-freq-base") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_base = std::stof(argv[i]); + } else if (arg == "--rope-freq-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = std::stof(argv[i]); + } else if (arg == "--rope-scaling") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; } + else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; } + else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; } + else { invalid_param = true; break; } + } else if (arg == "--rope-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = 1.0f/std::stof(argv[i]); + } else if (arg == "--yarn-orig-ctx") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.yarn_orig_ctx = std::stoi(argv[i]); + } else if (arg == "--yarn-ext-factor") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.yarn_ext_factor = std::stof(argv[i]); + } else if (arg == "--yarn-attn-factor") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.yarn_attn_factor = std::stof(argv[i]); + } else if (arg == "--yarn-beta-fast") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.yarn_beta_fast = std::stof(argv[i]); + } else if (arg == "--yarn-beta-slow") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.yarn_beta_slow = std::stof(argv[i]); + } else if (arg == "--samplers") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.samplers_sequence = parse_samplers_input(argv[i]); + } else if (arg == "--sampling-seq") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.samplers_sequence = argv[i]; + } else if (arg == "--top-p") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.top_p = std::stof(argv[i]); + } else if (arg == "--min-p") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.min_p = std::stof(argv[i]); + } else if (arg == "--temp") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.temp = std::stof(argv[i]); + sparams.temp = std::max(sparams.temp, 0.0f); + } else if (arg == "--tfs") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.tfs_z = std::stof(argv[i]); + } else if (arg == "--typical") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.typical_p = std::stof(argv[i]); + } else if (arg == "--repeat-last-n") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.penalty_last_n = std::stoi(argv[i]); + sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); + } else if (arg == "--repeat-penalty") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.penalty_repeat = 
std::stof(argv[i]); + } else if (arg == "--frequency-penalty") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.penalty_freq = std::stof(argv[i]); + } else if (arg == "--presence-penalty") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.penalty_present = std::stof(argv[i]); + } else if (arg == "--mirostat") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.mirostat = std::stoi(argv[i]); + } else if (arg == "--mirostat-lr") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.mirostat_eta = std::stof(argv[i]); + } else if (arg == "--mirostat-ent") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.mirostat_tau = std::stof(argv[i]); + } else if (arg == "--cfg-negative-prompt") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.cfg_negative_prompt = argv[i]; + } else if (arg == "--cfg-negative-prompt-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(sparams.cfg_negative_prompt)); + if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { + sparams.cfg_negative_prompt.pop_back(); + } + } else if (arg == "--cfg-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.cfg_scale = std::stof(argv[i]); + } else if (arg == "-b" || arg == "--batch-size") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_batch = std::stoi(argv[i]); + } else if (arg == "--keep") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_keep = std::stoi(argv[i]); + } else if (arg == "--draft") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_draft = std::stoi(argv[i]); + } else if (arg == "--chunks") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_chunks = std::stoi(argv[i]); + } else if (arg == "-np" || arg == "--parallel") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_parallel = std::stoi(argv[i]); + } else if (arg == "-ns" || arg == "--sequences") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_sequences = std::stoi(argv[i]); + } else if (arg == "--p-accept" || arg == "-pa") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.p_accept = std::stof(argv[i]); + } else if (arg == "--p-split" || arg == "-ps") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.p_split = std::stof(argv[i]); + } else if (arg == "-m" || arg == "--model") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.model = argv[i]; + } else if (arg == "-md" || arg == "--model-draft") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.model_draft = argv[i]; + } else if (arg == "-a" || arg == "--alias") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.model_alias = argv[i]; + } else if (arg == "--lora") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_adapter.push_back(std::make_tuple(argv[i], 1.0f)); + params.use_mmap = false; + } else if (arg == "--lora-scaled") { + if (++i >= argc) { + invalid_param = true; + break; + } + const char * lora_adapter = argv[i]; + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_adapter.push_back(std::make_tuple(lora_adapter, std::stof(argv[i]))); + 
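// --lora and --lora-scaled only collect (adapter_path, scale) tuples at parse time;
// they are applied later, in llama_init_from_gpt_params(), through
// llama_model_apply_lora_from_file(). A sketch of that application loop, assuming
// `model` and `params` are already set up (mirrors the loop further down in this file):
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
    const std::string & adapter = std::get<0>(params.lora_adapter[i]); // adapter file path
    const float         scale   = std::get<1>(params.lora_adapter[i]); // 1.0f for plain --lora
    // lora_base is only passed for the first adapter, matching the real loop
    const char * base = ((i > 0) || params.lora_base.empty()) ? NULL : params.lora_base.c_str();
    if (llama_model_apply_lora_from_file(model, adapter.c_str(), scale, base, params.n_threads) != 0) {
        fprintf(stderr, "failed to apply LoRA adapter '%s'\n", adapter.c_str());
        break;
    }
}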
params.use_mmap = false; + } else if (arg == "--lora-base") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_base = argv[i]; + } else if (arg == "--mmproj") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mmproj = argv[i]; + } else if (arg == "--image") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.image = argv[i]; + } else if (arg == "-i" || arg == "--interactive") { + params.interactive = true; + } else if (arg == "--embedding") { + params.embedding = true; + } else if (arg == "--interactive-first") { + params.interactive_first = true; + } else if (arg == "-ins" || arg == "--instruct") { + params.instruct = true; + } else if (arg == "-cml" || arg == "--chatml") { + params.chatml = true; + } else if (arg == "--infill") { + params.infill = true; + } else if (arg == "-dkvc" || arg == "--dump-kv-cache") { + params.dump_kv_cache = true; + } else if (arg == "-nkvo" || arg == "--no-kv-offload") { + params.no_kv_offload = true; + } else if (arg == "-ctk" || arg == "--cache-type-k") { + params.cache_type_k = argv[++i]; + } else if (arg == "-ctv" || arg == "--cache-type-v") { + params.cache_type_v = argv[++i]; + } else if (arg == "--multiline-input") { + params.multiline_input = true; + } else if (arg == "--simple-io") { + params.simple_io = true; + } else if (arg == "-cb" || arg == "--cont-batching") { + params.cont_batching = true; + } else if (arg == "--color") { + params.use_color = true; + } else if (arg == "--mlock") { + params.use_mlock = true; + } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { + if (++i >= argc) { + invalid_param = true; + break; + } +#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD + params.n_gpu_layers = std::stoi(argv[i]); +#else + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); +#endif + } else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { + if (++i >= argc) { + invalid_param = true; + break; + } +#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD + params.n_gpu_layers_draft = std::stoi(argv[i]); +#else + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); +#endif + } else if (arg == "--main-gpu" || arg == "-mg") { + if (++i >= argc) { + invalid_param = true; + break; + } +#ifdef GGML_USE_CUBLAS + params.main_gpu = std::stoi(argv[i]); +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n"); +#endif + } else if (arg == "--tensor-split" || arg == "-ts") { + if (++i >= argc) { + invalid_param = true; + break; + } +#ifdef GGML_USE_CUBLAS + std::string arg_next = argv[i]; + + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES); + + for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) { + if (i < split_arg.size()) { + params.tensor_split[i] = std::stof(split_arg[i]); + } else { + params.tensor_split[i] = 0.0f; + } + } +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. 
It is not possible to set a tensor split.\n"); +#endif // GGML_USE_CUBLAS + } else if (arg == "--no-mul-mat-q" || arg == "-nommq") { +#ifdef GGML_USE_CUBLAS + params.mul_mat_q = false; +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n"); +#endif // GGML_USE_CUBLAS + } else if (arg == "--no-mmap") { + params.use_mmap = false; + } else if (arg == "--numa") { + params.numa = true; + } else if (arg == "--verbose-prompt") { + params.verbose_prompt = true; + } else if (arg == "-r" || arg == "--reverse-prompt") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.antiprompt.push_back(argv[i]); + } else if (arg == "-ld" || arg == "--logdir") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.logdir = argv[i]; + + if (params.logdir.back() != DIRECTORY_SEPARATOR) { + params.logdir += DIRECTORY_SEPARATOR; + } + } else if (arg == "--perplexity" || arg == "--all-logits") { + params.logits_all = true; + } else if (arg == "--ppl-stride") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_stride = std::stoi(argv[i]); + } else if (arg == "--ppl-output-type") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_output_type = std::stoi(argv[i]); + } else if (arg == "--hellaswag") { + params.hellaswag = true; + } else if (arg == "--hellaswag-tasks") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.hellaswag_tasks = std::stoi(argv[i]); + } else if (arg == "--ignore-eos") { + params.ignore_eos = true; + } else if (arg == "--no-penalize-nl") { + sparams.penalize_nl = false; + } else if (arg == "-l" || arg == "--logit-bias") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::stringstream ss(argv[i]); + llama_token key; + char sign; + std::string value_str; + try { + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { + sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? 
-1.0f : 1.0f); + } else { + throw std::exception(); + } + } catch (const std::exception&) { + invalid_param = true; + break; + } + } else if (arg == "-h" || arg == "--help") { + return false; + + } else if (arg == "--version") { + fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); + fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); + exit(0); + } else if (arg == "--random-prompt") { + params.random_prompt = true; + } else if (arg == "--in-prefix-bos") { + params.input_prefix_bos = true; + } else if (arg == "--in-prefix") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.input_prefix = argv[i]; + } else if (arg == "--in-suffix") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.input_suffix = argv[i]; + } else if (arg == "--grammar") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.grammar = argv[i]; + } else if (arg == "--grammar-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(sparams.grammar) + ); + } else if (arg == "--override-kv") { + if (++i >= argc) { + invalid_param = true; + break; + } + char * sep = strchr(argv[i], '='); + if (sep == nullptr || sep - argv[i] >= 128) { + fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]); + invalid_param = true; + break; + } + struct llama_model_kv_override kvo; + std::strncpy(kvo.key, argv[i], sep - argv[i]); + kvo.key[sep - argv[i]] = 0; + sep++; + if (strncmp(sep, "int:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_INT; + kvo.int_value = std::atol(sep); + } else if (strncmp(sep, "float:", 6) == 0) { + sep += 6; + kvo.tag = LLAMA_KV_OVERRIDE_FLOAT; + kvo.float_value = std::atof(sep); + } else if (strncmp(sep, "bool:", 5) == 0) { + sep += 5; + kvo.tag = LLAMA_KV_OVERRIDE_BOOL; + if (std::strcmp(sep, "true") == 0) { + kvo.bool_value = true; + } else if (std::strcmp(sep, "false") == 0) { + kvo.bool_value = false; + } else { + fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); + invalid_param = true; + break; + } + } else { + fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); + invalid_param = true; + break; + } + params.kv_overrides.push_back(kvo); +#ifndef LOG_DISABLE_LOGS + // Parse args for logging parameters + } else if ( log_param_single_parse( argv[i] ) ) { + // Do nothing, log_param_single_parse automatically does it's thing + // and returns if a match was found and parsed. + } else if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) { + // We have a matching known parameter requiring an argument, + // now we need to check if there is anything after this argv + // and flag invalid_param or parse it. 
+ if (++i >= argc) { + invalid_param = true; + break; + } + if( !log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i-1], argv[i]) ) { + invalid_param = true; + break; + } + // End of Parse args for logging parameters +#endif // LOG_DISABLE_LOGS + } else { + throw std::invalid_argument("error: unknown argument: " + arg); + } + } + if (invalid_param) { + throw std::invalid_argument("error: invalid parameter for argument: " + arg); + } + if (params.prompt_cache_all && + (params.interactive || params.interactive_first || + params.instruct)) { + + throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); + } + + if (params.escape) { + process_escapes(params.prompt); + process_escapes(params.input_prefix); + process_escapes(params.input_suffix); + process_escapes(sparams.cfg_negative_prompt); + for (auto & antiprompt : params.antiprompt) { + process_escapes(antiprompt); + } + } + + if (!params.kv_overrides.empty()) { + params.kv_overrides.emplace_back(llama_model_kv_override()); + params.kv_overrides.back().key[0] = 0; + } + + return true; +} + +void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { + const llama_sampling_params & sparams = params.sparams; + + printf("\n"); + printf("usage: %s [options]\n", argv[0]); + printf("\n"); + printf("options:\n"); + printf(" -h, --help show this help message and exit\n"); + printf(" --version show version and build info\n"); + printf(" -i, --interactive run in interactive mode\n"); + printf(" --interactive-first run in interactive mode and wait for input right away\n"); + printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); + printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n"); + printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); + printf(" -r PROMPT, --reverse-prompt PROMPT\n"); + printf(" halt generation at PROMPT, return control in interactive mode\n"); + printf(" (can be specified more than once for multiple prompts).\n"); + printf(" --color colorise output to distinguish prompt and user input from generations\n"); + printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); + printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); + printf(" -tb N, --threads-batch N\n"); + printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); + printf(" -p PROMPT, --prompt PROMPT\n"); + printf(" prompt to start generation with (default: empty)\n"); + printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); + printf(" --prompt-cache FNAME file to cache prompt state for faster startup (default: none)\n"); + printf(" --prompt-cache-all if specified, saves user input and generations to cache as well.\n"); + printf(" not supported with --interactive or other interactive options\n"); + printf(" --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); + printf(" --random-prompt start with a randomized prompt.\n"); + printf(" --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n"); + printf(" --in-prefix STRING string to prefix user inputs with (default: empty)\n"); + printf(" --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); + printf(" -f FNAME, --file FNAME\n"); + printf(" prompt file to start generation.\n"); + printf(" -n N, --n-predict N number of 
tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); + printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); + printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); + printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n"); + printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str()); + printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); + printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); + printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); + printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); + printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); + printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); + printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat); + printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present); + printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq); + printf(" --mirostat N use Mirostat sampling.\n"); + printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); + printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); + printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)sparams.mirostat_eta); + printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)sparams.mirostat_tau); + printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n"); + printf(" modifies the likelihood of token appearing in the completion,\n"); + printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); + printf(" or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); + printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); + printf(" --grammar-file FNAME file to read grammar from\n"); + printf(" --cfg-negative-prompt PROMPT\n"); + printf(" negative prompt to use for guidance. (default: empty)\n"); + printf(" --cfg-negative-prompt-file FNAME\n"); + printf(" negative prompt file to use for guidance. 
(default: empty)\n"); + printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale); + printf(" --rope-scaling {none,linear,yarn}\n"); + printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n"); + printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n"); + printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); + printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n"); + printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n"); + printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n"); + printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); + printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); + printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); + printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); + printf(" --no-penalize-nl do not penalize newline token\n"); + printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); + printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); + printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); + printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); + printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); + printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); + printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); + printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); + printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); + printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept); + printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); + printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); + printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); + printf(" --image IMAGE_FILE path to an image file. 
use with multimodal models\n"); + if (llama_mlock_supported()) { + printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); + } + if (llama_mmap_supported()) { + printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); + } + printf(" --numa attempt optimizations that help on some NUMA systems\n"); + printf(" if run without this previously, it is recommended to drop the system page cache before using this\n"); + printf(" see https://github.com/ggerganov/llama.cpp/issues/1437\n"); +#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD + printf(" -ngl N, --n-gpu-layers N\n"); + printf(" number of layers to store in VRAM\n"); + printf(" -ngld N, --n-gpu-layers-draft N\n"); + printf(" number of layers to store in VRAM for the draft model\n"); + printf(" -ts SPLIT --tensor-split SPLIT\n"); + printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); + printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); +#ifdef GGML_USE_CUBLAS + printf(" -nommq, --no-mul-mat-q\n"); + printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n"); + printf(" Not recommended since this is both slower and uses more VRAM.\n"); +#endif // GGML_USE_CUBLAS +#endif + printf(" --verbose-prompt print prompt before generation\n"); + printf(" -dkvc, --dump-kv-cache\n"); + printf(" verbose print of the KV cache\n"); + printf(" -nkvo, --no-kv-offload\n"); + printf(" disable KV offload\n"); + printf(" -ctk TYPE, --cache-type-k TYPE\n"); + printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str()); + printf(" -ctv TYPE, --cache-type-v TYPE\n"); + printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str()); + printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); + printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); + printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); + printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); + printf(" -m FNAME, --model FNAME\n"); + printf(" model path (default: %s)\n", params.model.c_str()); + printf(" -md FNAME, --model-draft FNAME\n"); + printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str()); + printf(" -ld LOGDIR, --logdir LOGDIR\n"); + printf(" path under which to save YAML logs (no logging if unset)\n"); + printf(" --override-kv KEY=TYPE:VALUE\n"); + printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); + printf(" types: int, float, bool. 
example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); + printf("\n"); +#ifndef LOG_DISABLE_LOGS + log_print_usage(); +#endif // LOG_DISABLE_LOGS +} + +std::string get_system_info(const gpt_params & params) { + std::ostringstream os; + + os << "system_info: n_threads = " << params.n_threads; + if (params.n_threads_batch != -1) { + os << " (n_threads_batch = " << params.n_threads_batch << ")"; + } + os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); + + return os.str(); +} + +std::string gpt_random_prompt(std::mt19937 & rng) { + const int r = rng() % 10; + switch (r) { + case 0: return "So"; + case 1: return "Once upon a time"; + case 2: return "When"; + case 3: return "The"; + case 4: return "After"; + case 5: return "If"; + case 6: return "import"; + case 7: return "He"; + case 8: return "She"; + case 9: return "They"; + } + + GGML_UNREACHABLE(); +} + +// +// String parsing +// + +std::string parse_samplers_input(std::string input) { + std::string output = ""; + // since samplers names are written multiple ways + // make it ready for both system names and input names + std::unordered_map samplers_symbols { + {"top_k", 'k'}, + {"top-k", 'k'}, + {"top_p", 'p'}, + {"top-p", 'p'}, + {"nucleus", 'p'}, + {"typical_p", 'y'}, + {"typical-p", 'y'}, + {"typical", 'y'}, + {"min_p", 'm'}, + {"min-p", 'm'}, + {"tfs_z", 'f'}, + {"tfs-z", 'f'}, + {"tfs", 'f'}, + {"temp", 't'}, + {"temperature",'t'} + }; + // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p" + size_t separator = input.find(';'); + while (separator != input.npos) { + std::string name = input.substr(0,separator); + input = input.substr(separator+1); + separator = input.find(';'); + + if (samplers_symbols.find(name) != samplers_symbols.end()) { + output += samplers_symbols[name]; + } + } + if (samplers_symbols.find(input) != samplers_symbols.end()) { + output += samplers_symbols[input]; + } + return output; +} + +// +// Model utils +// + +struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) { + auto mparams = llama_model_default_params(); + + if (params.n_gpu_layers != -1) { + mparams.n_gpu_layers = params.n_gpu_layers; + } + mparams.main_gpu = params.main_gpu; + mparams.tensor_split = params.tensor_split; + mparams.use_mmap = params.use_mmap; + mparams.use_mlock = params.use_mlock; + if (params.kv_overrides.empty()) { + mparams.kv_overrides = NULL; + } else { + GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key"); + mparams.kv_overrides = params.kv_overrides.data(); + } + + return mparams; +} + +static ggml_type kv_cache_type_from_str(const std::string & s) { + if (s == "f16") { + return GGML_TYPE_F16; + } + if (s == "q8_0") { + return GGML_TYPE_Q8_0; + } + if (s == "q4_0") { + return GGML_TYPE_Q4_0; + } + if (s == "q4_1") { + return GGML_TYPE_Q4_1; + } + if (s == "q5_0") { + return GGML_TYPE_Q5_0; + } + if (s == "q5_1") { + return GGML_TYPE_Q5_1; + } + + throw std::runtime_error("Invalid cache type: " + s); +} + +struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { + auto cparams = llama_context_default_params(); + + cparams.n_ctx = params.n_ctx; + cparams.n_batch = params.n_batch; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch == -1 ? 
params.n_threads : params.n_threads_batch; + cparams.mul_mat_q = params.mul_mat_q; + cparams.seed = params.seed; + cparams.logits_all = params.logits_all; + cparams.embedding = params.embedding; + cparams.rope_scaling_type = params.rope_scaling_type; + cparams.rope_freq_base = params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale; + cparams.yarn_ext_factor = params.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow; + cparams.yarn_orig_ctx = params.yarn_orig_ctx; + cparams.offload_kqv = !params.no_kv_offload; + + cparams.type_k = kv_cache_type_from_str(params.cache_type_k); + cparams.type_v = kv_cache_type_from_str(params.cache_type_v); + + return cparams; +} + +void llama_batch_clear(struct llama_batch & batch) { + batch.n_tokens = 0; +} + +void llama_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector & seq_ids, + bool logits) { + batch.token [batch.n_tokens] = id; + batch.pos [batch.n_tokens] = pos; + batch.n_seq_id[batch.n_tokens] = seq_ids.size(); + for (size_t i = 0; i < seq_ids.size(); ++i) { + batch.seq_id[batch.n_tokens][i] = seq_ids[i]; + } + batch.logits [batch.n_tokens] = logits; + + batch.n_tokens++; +} + +std::tuple llama_init_from_gpt_params(gpt_params & params) { + auto mparams = llama_model_params_from_gpt_params(params); + + llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + return std::make_tuple(nullptr, nullptr); + } + + auto cparams = llama_context_params_from_gpt_params(params); + + llama_context * lctx = llama_new_context_with_model(model, cparams); + if (lctx == NULL) { + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + + for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { + const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]); + float lora_scale = std::get<1>(params.lora_adapter[i]); + int err = llama_model_apply_lora_from_file(model, + lora_adapter.c_str(), + lora_scale, + ((i > 0) || params.lora_base.empty()) + ? 
NULL + : params.lora_base.c_str(), + params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + } + + if (params.ignore_eos) { + params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + } + + { + LOG("warming up the model with an empty run\n"); + + std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + llama_kv_cache_clear(lctx); + llama_reset_timings(lctx); + } + + return std::make_tuple(model, lctx); +} + +// +// Vocab utils +// + +std::vector llama_tokenize( + const struct llama_context * ctx, + const std::string & text, + bool add_bos, + bool special) { + return llama_tokenize(llama_get_model(ctx), text, add_bos, special); +} + +std::vector llama_tokenize( + const struct llama_model * model, + const std::string & text, + bool add_bos, + bool special) { + // upper limit for the number of tokens + int n_tokens = text.length() + add_bos; + std::vector result(n_tokens); + n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; +} + +std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + + return std::string(result.data(), result.size()); +} + +std::string llama_detokenize_spm(llama_context * ctx, const std::vector & tokens) { + const llama_token bos_id = llama_token_bos(llama_get_model(ctx)); + + std::string piece; + std::string result; + + for (size_t i = 0; i < tokens.size(); ++i) { + piece = llama_token_to_piece(ctx, tokens[i]); + + // remove the leading space of the first non-BOS token + if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') { + piece = piece.substr(1); + } + + result += piece; + } + + return result; +} + +std::string llama_detokenize_bpe(llama_context * ctx, const std::vector & tokens) { + std::string piece; + std::string result; + + for (size_t i = 0; i < tokens.size(); ++i) { + piece = llama_token_to_piece(ctx, tokens[i]); + + result += piece; + } + + // NOTE: the original tokenizer decodes bytes after collecting the pieces. + return result; +} + +bool llama_should_add_bos_token(const llama_model * model) { + const int add_bos = llama_add_bos_token(model); + + return add_bos != -1 ? 
bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); +} + +// +// YAML utils +// + +// returns true if successful, false otherwise +bool create_directory_with_parents(const std::string & path) { +#ifdef _WIN32 + std::wstring_convert> converter; + std::wstring wpath = converter.from_bytes(path); + + // if the path already exists, check whether it's a directory + const DWORD attributes = GetFileAttributesW(wpath.c_str()); + if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return true; + } + + size_t pos_slash = 0; + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { + const std::wstring subpath = wpath.substr(0, pos_slash); + const wchar_t * test = subpath.c_str(); + + const bool success = CreateDirectoryW(test, NULL); + if (!success) { + const DWORD error = GetLastError(); + + // if the path already exists, ensure that it's a directory + if (error == ERROR_ALREADY_EXISTS) { + const DWORD attributes = GetFileAttributesW(subpath.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return false; + } + } else { + return false; + } + } + + pos_slash += 1; + } + + return true; +#else + // if the path already exists, check whether it's a directory + struct stat info; + if (stat(path.c_str(), &info) == 0) { + return S_ISDIR(info.st_mode); + } + + size_t pos_slash = 1; // skip leading slashes for directory creation + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { + const std::string subpath = path.substr(0, pos_slash); + struct stat info; + + // if the path already exists, ensure that it's a directory + if (stat(subpath.c_str(), &info) == 0) { + if (!S_ISDIR(info.st_mode)) { + return false; + } + } else { + // create parent directories + const int ret = mkdir(subpath.c_str(), 0755); + if (ret != 0) { + return false; + } + } + + pos_slash += 1; + } + + return true; +#endif // _WIN32 +} + +void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data) { + if (data.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + fprintf(stream, "%s: [", prop_name); + for (size_t i = 0; i < data.size() - 1; ++i) { + fprintf(stream, "%e, ", data[i]); + } + fprintf(stream, "%e]\n", data.back()); +} + +void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data) { + if (data.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + fprintf(stream, "%s: [", prop_name); + for (size_t i = 0; i < data.size() - 1; ++i) { + fprintf(stream, "%d, ", data[i]); + } + fprintf(stream, "%d]\n", data.back()); +} + +void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data) { + std::string data_str(data == NULL ? 
"" : data); + + if (data_str.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + size_t pos_start = 0; + size_t pos_found = 0; + + if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) { + data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); + data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); + data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)"); + data_str = "\"" + data_str + "\""; + fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); + return; + } + + if (data_str.find('\n') == std::string::npos) { + fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); + return; + } + + fprintf(stream, "%s: |\n", prop_name); + while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { + fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); + pos_start = pos_found + 1; + } +} + +std::string get_sortable_timestamp() { + using clock = std::chrono::system_clock; + + const clock::time_point current_time = clock::now(); + const time_t as_time_t = clock::to_time_t(current_time); + char timestamp_no_ns[100]; + std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t)); + + const int64_t ns = std::chrono::duration_cast( + current_time.time_since_epoch() % 1000000000).count(); + char timestamp_ns[11]; + snprintf(timestamp_ns, 11, "%09" PRId64, ns); + + return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); +} + +void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, + const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { + const llama_sampling_params & sparams = params.sparams; + + fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); + fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); + fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); + fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); + fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); + fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false"); + fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false"); + fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false"); + fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); + fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); + fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); + fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); + fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); + fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); + fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? 
"true" : "false"); + +#ifdef NDEBUG + fprintf(stream, "debug: false\n"); +#else + fprintf(stream, "debug: true\n"); +#endif // NDEBUG + + fprintf(stream, "model_desc: %s\n", model_desc); + fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); + +#ifdef __OPTIMIZE__ + fprintf(stream, "optimize: true\n"); +#else + fprintf(stream, "optimize: false\n"); +#endif // __OPTIMIZE__ + + fprintf(stream, "time: %s\n", timestamp.c_str()); + + fprintf(stream, "\n"); + fprintf(stream, "###############\n"); + fprintf(stream, "# User Inputs #\n"); + fprintf(stream, "###############\n"); + fprintf(stream, "\n"); + + fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); + fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); + dump_string_yaml_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); + fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); + fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); + fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); + fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); + fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); + fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); + fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); + dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str()); + fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); + fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); + fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); + + const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); + const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; + fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); + + dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str()); + fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); + dump_string_yaml_multiline(stream, "in_suffix", params.input_prefix.c_str()); + fprintf(stream, "instruct: %s # default: false\n", params.instruct ? "true" : "false"); + fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); + fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? 
"true" : "false"); + fprintf(stream, "keep: %d # default: 0\n", params.n_keep); + fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); + + fprintf(stream, "logit_bias:\n"); + for (std::pair lb : sparams.logit_bias) { + if (ignore_eos && lb.first == logit_bias_eos->first) { + continue; + } + fprintf(stream, " %d: %f", lb.first, lb.second); + } + + fprintf(stream, "lora:\n"); + for (std::tuple la : params.lora_adapter) { + if (std::get<1>(la) != 1.0f) { + continue; + } + fprintf(stream, " - %s\n", std::get<0>(la).c_str()); + } + fprintf(stream, "lora_scaled:\n"); + for (std::tuple la : params.lora_adapter) { + if (std::get<1>(la) == 1.0f) { + continue; + } + fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la)); + } + fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); + fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); + fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); + fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); + fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); + fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); + fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); + fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); + fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); + fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); + fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); + fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); + fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); + fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false"); + fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? "true" : "false"); + fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false"); + fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); + fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); + fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); + dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); + fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); + fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); + fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); + dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); + fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); + fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat); + + fprintf(stream, "reverse_prompt:\n"); + for (std::string ap : params.antiprompt) { + size_t pos = 0; + while ((pos = ap.find('\n', pos)) != std::string::npos) { + ap.replace(pos, 1, "\\n"); + pos += 1; + } + + fprintf(stream, " - %s\n", ap.c_str()); + } + + fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); + fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); + fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); + fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? 
"true" : "false"); + fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); + fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); + + const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); + dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); + + fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); + fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); + fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); + fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); + fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); + fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); +} + +// +// KV cache utils +// + +void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) { + static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+"; + + printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d", + view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx); + + llama_kv_cache_view_cell * c_curr = view.cells; + llama_seq_id * cs_curr = view.cells_sequences; + + for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { + if (i % row_size == 0) { + printf("\n%5d: ", i); + } + int seq_count = 0; + for (int j = 0; j < view.n_max_seq; j++) { + if (cs_curr[j] >= 0) { seq_count++; } + } + putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]); + } + + printf("\n=== Done dumping\n"); +} + +void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { + static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n", + view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx); + + std::unordered_map seqs; + llama_kv_cache_view_cell * c_curr = view.cells; + llama_seq_id * cs_curr = view.cells_sequences; + + for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { + for (int j = 0; j < view.n_max_seq; j++) { + if (cs_curr[j] < 0) { continue; } + if (seqs.find(cs_curr[j]) == seqs.end()) { + if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } + seqs[cs_curr[j]] = seqs.size(); + } + } + if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } + } + + printf("=== Sequence legend: "); + for (const auto & it : seqs) { + printf("%zu=%d, ", it.second, it.first); + } + printf("'+'=other sequence ids"); + + c_curr = view.cells; + cs_curr = view.cells_sequences; + for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { + if (i % row_size == 0) { + printf("\n%5d: ", i); + } + for (int j = 0; j < view.n_max_seq; j++) { + if (cs_curr[j] >= 0) { + const auto & it = seqs.find(cs_curr[j]); + putchar(it != seqs.end() ? 
int(slot_chars[it->second]) : '+'); + } else { + putchar('.'); + } + } + putchar(' '); + } + + printf("\n=== Done dumping\n"); +} diff --git a/applications/src/llama/common/common.h b/applications/src/llama/common/common.h new file mode 100644 index 000000000..e87ce1133 --- /dev/null +++ b/applications/src/llama/common/common.h @@ -0,0 +1,242 @@ +// Various helper functions and utilities + +#pragma once + +#include "llama.h" + +#include "sampling.h" + +#define LOG_NO_FILE_LINE_FUNCTION +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#define DIRECTORY_SEPARATOR '\\' +#else +#define DIRECTORY_SEPARATOR '/' +#endif // _WIN32 + +#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) +#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) + +#define print_build_info() do { \ + fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ + fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ +} while(0) + +// build info +extern int LLAMA_BUILD_NUMBER; +extern char const *LLAMA_COMMIT; +extern char const *LLAMA_COMPILER; +extern char const *LLAMA_BUILD_TARGET; + +// +// CLI argument parsing +// +int32_t get_num_physical_cores(); + +struct gpt_params { + uint32_t seed = -1; // RNG seed + + int32_t n_threads = get_num_physical_cores(); + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 512; // context size + int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_draft = 16; // number of tokens to draft during speculative decoding + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode + float p_accept = 0.5f; // speculative decoding accept probability + float p_split = 0.1f; // speculative decoding split probability + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs + int32_t n_beams = 0; // if non-zero then use beam search of given width. 
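+    // RoPE / YaRN context-scaling parameters (0.0f / -1.0f below generally means "use the model's own value")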
+ float rope_freq_base = 0.0f; // RoPE base frequency + float rope_freq_scale = 0.0f; // RoPE frequency scaling factor + float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor + float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor + float yarn_beta_fast = 32.0f; // YaRN low correction dim + float yarn_beta_slow = 1.0f; // YaRN high correction dim + int32_t yarn_orig_ctx = 0; // YaRN original context length + int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment + // pinging @cebtenzzre + + // // sampling parameters + struct llama_sampling_params sparams; + + std::string model = "models/7B/ggml-model-f16.gguf"; // model path + std::string model_draft = ""; // draft model for speculative decoding + std::string model_alias = "unknown"; // model alias + std::string prompt = ""; + std::string prompt_file = ""; // store the external prompt file name + std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state + std::string input_prefix = ""; // string to prefix user inputs with + std::string input_suffix = ""; // string to suffix user inputs with + std::vector antiprompt; // string upon seeing which more user input is prompted + std::string logdir = ""; // directory in which to save YAML log files + + std::vector kv_overrides; + + // TODO: avoid tuple, use struct + std::vector> lora_adapter; // lora adapter path with user defined scale + std::string lora_base = ""; // base model path for the lora adapter + + int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // + bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt + size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score + + bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS + bool random_prompt = false; // do not randomize prompt if none provided + bool use_color = false; // use color to distinguish generations and inputs + bool interactive = false; // interactive mode + bool chatml = false; // chatml mode (used for models trained on chatml syntax) + bool prompt_cache_all = false; // save user input and generations to prompt cache + bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it + + bool embedding = false; // get only sentence embedding + bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\" + bool interactive_first = false; // wait for user input immediately + bool multiline_input = false; // reverse the usage of `\` + bool simple_io = false; // improves compatibility with subprocesses and limited consoles + bool cont_batching = false; // insert new sequences for decoding on-the-fly + + bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix + bool ignore_eos = false; // ignore generated EOS tokens + bool instruct = false; // instruction mode (used for Alpaca models) + bool logits_all = false; // return logits for all tokens in the batch + bool use_mmap = true; // use mmap for faster loads + bool use_mlock = false; // use mlock to keep model in memory + bool numa = false; // attempt optimizations that help on some NUMA systems + bool verbose_prompt = false; // print prompt tokens before generation + bool infill = false; // use infill 
mode + bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes + bool no_kv_offload = false; // disable KV offloading + + std::string cache_type_k = "f16"; // KV cache data type for the K + std::string cache_type_v = "f16"; // KV cache data type for the V + + // multimodal models (see examples/llava) + std::string mmproj = ""; // path to multimodal projector + std::string image = ""; // path to an image file +}; + +bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); + +bool gpt_params_parse(int argc, char ** argv, gpt_params & params); + +void gpt_print_usage(int argc, char ** argv, const gpt_params & params); + +std::string get_system_info(const gpt_params & params); + +std::string gpt_random_prompt(std::mt19937 & rng); + +void process_escapes(std::string& input); + +// +// String parsing +// + +std::string parse_samplers_input(std::string input); + +// +// Model utils +// + +// TODO: avoid tuplue, use struct +std::tuple llama_init_from_gpt_params(gpt_params & params); + +struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); +struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); + +// Batch utils + +void llama_batch_clear(struct llama_batch & batch); + +void llama_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector & seq_ids, + bool logits); + +// +// Vocab utils +// + +// tokenizes a string into a vector of tokens +// should work similar to Python's `tokenizer.encode` +std::vector llama_tokenize( + const struct llama_context * ctx, + const std::string & text, + bool add_bos, + bool special = false); + +std::vector llama_tokenize( + const struct llama_model * model, + const std::string & text, + bool add_bos, + bool special = false); + +// tokenizes a token into a piece +// should work similar to Python's `tokenizer.id_to_piece` +std::string llama_token_to_piece( + const struct llama_context * ctx, + llama_token token); + +// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function +// that takes into account the tokenizer type and decides how to handle the leading space +// +// detokenizes a vector of tokens into a string +// should work similar to Python's `tokenizer.decode` +// removes the leading space from the first non-BOS token +std::string llama_detokenize_spm( + llama_context * ctx, + const std::vector & tokens); + +// detokenizes a vector of tokens into a string +// should work similar to Python's `tokenizer.decode` +std::string llama_detokenize_bpe( + llama_context * ctx, + const std::vector & tokens); + +// Uses the value from the model metadata if possible, otherwise +// defaults to true when model type is SPM, otherwise false. +bool llama_should_add_bos_token(const llama_model * model); + +// +// YAML utils +// + +bool create_directory_with_parents(const std::string & path); +void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data); +void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data); +void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data); +std::string get_sortable_timestamp(); + +void dump_non_result_info_yaml( + FILE * stream, const gpt_params & params, const llama_context * lctx, + const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); + +// +// KV cache utils +// + +// Dump the KV cache view with the number of sequences per cell. 
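+// Each cell is rendered as one character: '.' for an empty cell, a digit/letter for the
+// number of sequences it holds, and '+' once the count exceeds the legend.
+//
+// Hypothetical debugging sketch (assuming the llama_kv_cache_view_* helpers from llama.h,
+// with ctx being a llama_context *):
+//   llama_kv_cache_view view = llama_kv_cache_view_init(ctx, /*n_max_seq=*/1);
+//   llama_kv_cache_view_update(ctx, &view);
+//   dump_kv_cache_view(view, 80);
+//   llama_kv_cache_view_free(&view);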
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); + +// Dump the KV cache view showing individual sequences in each cell (long output). +void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); diff --git a/applications/src/llama/common/console.cpp b/applications/src/llama/common/console.cpp new file mode 100644 index 000000000..f65cbc6ed --- /dev/null +++ b/applications/src/llama/common/console.cpp @@ -0,0 +1,501 @@ +#include "console.h" +#include +#include + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#include +#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING +#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004 +#endif +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#define ANSI_COLOR_RED "\x1b[31m" +#define ANSI_COLOR_GREEN "\x1b[32m" +#define ANSI_COLOR_YELLOW "\x1b[33m" +#define ANSI_COLOR_BLUE "\x1b[34m" +#define ANSI_COLOR_MAGENTA "\x1b[35m" +#define ANSI_COLOR_CYAN "\x1b[36m" +#define ANSI_COLOR_RESET "\x1b[0m" +#define ANSI_BOLD "\x1b[1m" + +namespace console { + + // + // Console state + // + + static bool advanced_display = false; + static bool simple_io = true; + static display_t current_display = reset; + + static FILE* out = stdout; + +#if defined (_WIN32) + static void* hConsole; +#else + static FILE* tty = nullptr; + static termios initial_state; +#endif + + // + // Init and cleanup + // + + void init(bool use_simple_io, bool use_advanced_display) { + advanced_display = use_advanced_display; + simple_io = use_simple_io; +#if defined(_WIN32) + // Windows-specific console initialization + DWORD dwMode = 0; + hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) { + hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) { + hConsole = nullptr; + simple_io = true; + } + } + if (hConsole) { + // Check conditions combined to reduce nesting + if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) && + !SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + advanced_display = false; + } + // Set console output codepage to UTF8 + SetConsoleOutputCP(CP_UTF8); + } + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { + // Set console input codepage to UTF16 + _setmode(_fileno(stdin), _O_WTEXT); + + // Set ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + if (simple_io) { + dwMode |= ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT; + } else { + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + } + if (!SetConsoleMode(hConIn, dwMode)) { + simple_io = true; + } + } +#else + // POSIX-specific console initialization + if (!simple_io) { + struct termios new_termios; + tcgetattr(STDIN_FILENO, &initial_state); + new_termios = initial_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + tty = fopen("/dev/tty", "w+"); + if (tty != nullptr) { + out = tty; + } + } + + setlocale(LC_ALL, ""); +#endif + } + + void cleanup() { + // Reset console display + set_display(reset); + +#if !defined(_WIN32) + // Restore settings on POSIX systems + if (!simple_io) { + if (tty != nullptr) { + out = stdout; + fclose(tty); + tty = nullptr; + } + tcsetattr(STDIN_FILENO, TCSANOW, &initial_state); + } 
+#endif + } + + // + // Display and IO + // + + // Keep track of current display and only emit ANSI code if it changes + void set_display(display_t display) { + if (advanced_display && current_display != display) { + fflush(stdout); + switch(display) { + case reset: + fprintf(out, ANSI_COLOR_RESET); + break; + case prompt: + fprintf(out, ANSI_COLOR_YELLOW); + break; + case user_input: + fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + case error: + fprintf(out, ANSI_BOLD ANSI_COLOR_RED); + } + current_display = display; + fflush(out); + } + } + + static char32_t getchar32() { +#if defined(_WIN32) + HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); + wchar_t high_surrogate = 0; + + while (true) { + INPUT_RECORD record; + DWORD count; + if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) { + return WEOF; + } + + if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) { + wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; + if (wc == 0) { + continue; + } + + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + high_surrogate = wc; + continue; + } + if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate + if (high_surrogate != 0) { // Check if we have a high surrogate + return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000; + } + } + + high_surrogate = 0; // Reset the high surrogate + return static_cast(wc); + } + } +#else + wchar_t wc = getwchar(); + if (static_cast(wc) == WEOF) { + return WEOF; + } + +#if WCHAR_MAX == 0xFFFF + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + wchar_t low_surrogate = getwchar(); + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; + } + } + if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + return 0xFFFD; // Return the replacement character U+FFFD + } +#endif + + return static_cast(wc); +#endif + } + + static void pop_cursor() { +#if defined(_WIN32) + if (hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(hConsole, newCursorPosition); + return; + } +#endif + putc('\b', out); + } + + static int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + (void)codepoint; + return 1; +#else + return wcwidth(codepoint); +#endif + } + + static int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X 
- initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // We can trust expectedWidth if we've got one + if (expectedWidth >= 0 || tty == nullptr) { + fwrite(utf8_codepoint, length, 1, out); + return expectedWidth; + } + + fputs("\033[6n", tty); // Query cursor position + int x1; + int y1; + int x2; + int y2; + int results = 0; + results = fscanf(tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, tty); + + fputs("\033[6n", tty); // Query cursor position + results += fscanf(tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif + } + + static void replace_last(char ch) { +#if defined(_WIN32) + pop_cursor(); + put_codepoint(&ch, 1, 1); +#else + fprintf(out, "\b%c", ch); +#endif + } + + static void append_utf8(char32_t ch, std::string & out) { + if (ch <= 0x7F) { + out.push_back(static_cast(ch)); + } else if (ch <= 0x7FF) { + out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0xFFFF) { + out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0x10FFFF) { + out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); + out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else { + // Invalid Unicode code point + } + } + + // Helper function to remove the last UTF-8 character from a string + static void pop_back_utf8_char(std::string & line) { + if (line.empty()) { + return; + } + + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) { + break; // Found the start of the character + } + } + line.erase(pos); + } + + static bool readline_advanced(std::string & line, bool multiline_input) { + if (out != stdout) { + fflush(stdout); + } + + line.clear(); + std::vector widths; + bool is_special_char = false; + bool end_of_stream = false; + + char32_t input_char; + while (true) { + fflush(out); // Ensure all output is displayed before waiting for input + input_char = getchar32(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + set_display(user_input); + replace_last(line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + char32_t code = getchar32(); + if (code == '[' || code == 0x1B) { + // Discard the rest of the escape sequence + while ((code = getchar32()) != (char32_t) WEOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!widths.empty()) { + int count; + do { + count = widths.back(); + widths.pop_back(); + // Move cursor back, print space, and move cursor back again + for (int i = 0; i < count; i++) { + replace_last(' '); + pop_cursor(); + } + pop_back_utf8_char(line); + } while (count == 0 && !widths.empty()); + } + } else { + 
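+                // Printable character: append its UTF-8 encoding to the line and remember the
+                // rendered width so backspace handling can later erase the correct number of columns.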
int offset = line.length(); + append_utf8(input_char, line); + int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } + widths.push_back(width); + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + set_display(prompt); + replace_last(line.back()); + is_special_char = true; + } + } + + bool has_more = multiline_input; + if (is_special_char) { + replace_last(' '); + pop_cursor(); + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + fputc('\n', out); + has_more = !has_more; + } else { + // llama will just eat the single space, it won't act as a space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + pop_cursor(); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + fputc('\n', out); + } + } + + fflush(out); + return has_more; + } + + static bool readline_simple(std::string & line, bool multiline_input) { +#if defined(_WIN32) + std::wstring wline; + if (!std::getline(std::wcin, wline)) { + // Input stream is bad or EOF received + line.clear(); + GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0); + return false; + } + + int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL); + line.resize(size_needed); + WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL); +#else + if (!std::getline(std::cin, line)) { + // Input stream is bad or EOF received + line.clear(); + return false; + } +#endif + if (!line.empty()) { + char last = line.back(); + if (last == '/') { // Always return control on '/' symbol + line.pop_back(); + return false; + } + if (last == '\\') { // '\\' changes the default action + line.pop_back(); + multiline_input = !multiline_input; + } + } + line += '\n'; + + // By default, continue input if multiline_input is set + return multiline_input; + } + + bool readline(std::string & line, bool multiline_input) { + set_display(user_input); + + if (simple_io) { + return readline_simple(line, multiline_input); + } + return readline_advanced(line, multiline_input); + } + +} diff --git a/applications/src/llama/common/console.h b/applications/src/llama/common/console.h new file mode 100644 index 000000000..ec175269b --- /dev/null +++ b/applications/src/llama/common/console.h @@ -0,0 +1,19 @@ +// Console functions + +#pragma once + +#include + +namespace console { + enum display_t { + reset = 0, + prompt, + user_input, + error + }; + + void init(bool use_simple_io, bool use_advanced_display); + void cleanup(); + void set_display(display_t display); + bool readline(std::string & line, bool multiline_input); +} diff --git a/applications/src/llama/common/grammar-parser.cpp b/applications/src/llama/common/grammar-parser.cpp new file mode 100644 index 000000000..bf89a96f3 --- /dev/null +++ b/applications/src/llama/common/grammar-parser.cpp @@ -0,0 +1,424 @@ +#include "grammar-parser.h" +#include +#include +#include +#include +#include +#include + +namespace grammar_parser { + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from llama.cpp + static std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may 
overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); + } + + static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { + uint32_t next_id = static_cast(state.symbol_ids.size()); + auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); + return result.first->second; + } + + static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { + uint32_t next_id = static_cast(state.symbol_ids.size()); + state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; + } + + static void add_rule( + parse_state & state, + uint32_t rule_id, + const std::vector & rule) { + if (state.rules.size() <= rule_id) { + state.rules.resize(rule_id + 1); + } + state.rules[rule_id] = rule; + } + + static bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + } + + static std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); + } + + static const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } + } + return pos; + } + + static const char * parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting name at ") + src); + } + return pos; + } + + static std::pair parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); + } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); + + static const char * parse_sequence( + parse_state & state, + const char * src, + const std::string & rule_name, + std::vector & out_elements, + bool is_nested) { + size_t last_sym_start = out_elements.size(); + const char * pos = src; + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = out_elements.size(); + while (*pos != '"') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == 
'[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = out_elements.size(); + while (*pos != ']') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < out_elements.size() + ? LLAMA_GRETYPE_CHAR_ALT + : start_type; + + out_elements.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); + last_sym_start = out_elements.size(); + // output reference to synthesized rule + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator + if (last_sym_start == out_elements.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? 
--> S' ::= S | + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + std::vector sub_rule; + // add preceding symbol to generated rule + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + if (*pos == '*' || *pos == '+') { + // cause generated rule to recurse + sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + } + // mark start of alternate def + sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + if (*pos == '+') { + // add preceding symbol as alternate only for '+' (otherwise empty) + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + } + sub_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, sub_rule_id, sub_rule); + + // in original rule, replace previous symbol with reference to generated rule + out_elements.resize(last_sym_start); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + + pos = parse_space(pos + 1, is_nested); + } else { + break; + } + } + return pos; + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { + std::vector rule; + const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); + while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); + pos = parse_space(pos + 1, true); + pos = parse_sequence(state, pos, rule_name, rule, is_nested); + } + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rule_id, rule); + return pos; + } + + static const char * parse_rule(parse_state & state, const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(state, src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); + } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(state, pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 
2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); + } + + parse_state parse(const char * src) { + try { + parse_state state; + const char * pos = parse_space(src, true); + while (*pos) { + pos = parse_rule(state, pos); + } + return state; + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + return parse_state(); + } + } + + static void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } + } + + static bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + default: return false; + } + } + + static void print_rule_binary(FILE * file, const std::vector & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + } + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); + } + + static void print_rule( + FILE * file, + uint32_t rule_id, + const std::vector & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, 
elem.value); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + break; + default: + fprintf(file, "] "); + } + } + } + fprintf(file, "\n"); + } + + void print_grammar(FILE * file, const parse_state & state) { + try { + std::map symbol_id_names; + for (const auto & kv : state.symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = state.rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, state.rules[i]); + print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); + // fprintf(file, "\n"); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); + } + } + + std::vector parse_state::c_rules() { + std::vector ret; + ret.reserve(rules.size()); + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; + } +} diff --git a/applications/src/llama/common/grammar-parser.h b/applications/src/llama/common/grammar-parser.h new file mode 100644 index 000000000..9037d7272 --- /dev/null +++ b/applications/src/llama/common/grammar-parser.h @@ -0,0 +1,29 @@ +// Implements a parser for an extended Backus-Naur form (BNF), producing the +// binary context-free grammar format specified by llama.h. Supports character +// ranges, grouping, and repetition operators. As an example, a grammar for +// arithmetic might look like: +// +// root ::= expr +// expr ::= term ([-+*/] term)* +// term ::= num | "(" space expr ")" space +// num ::= [0-9]+ space +// space ::= [ \t\n]* + +#pragma once +#include "llama.h" +#include +#include +#include +#include + +namespace grammar_parser { + struct parse_state { + std::map symbol_ids; + std::vector> rules; + + std::vector c_rules(); + }; + + parse_state parse(const char * src); + void print_grammar(FILE * file, const parse_state & state); +} diff --git a/applications/src/llama/common/log.h b/applications/src/llama/common/log.h new file mode 100644 index 000000000..e4e1b9f4f --- /dev/null +++ b/applications/src/llama/common/log.h @@ -0,0 +1,723 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +// -------------------------------- +// +// Basic usage: +// +// -------- +// +// The LOG() and LOG_TEE() macros are ready to go by default +// they do not require any initialization. +// +// LOGLN() and LOG_TEELN() are variants which automatically +// include \n character at the end of the log string. +// +// LOG() behaves exactly like printf, by default writing to a logfile. +// LOG_TEE() additionally, prints to the screen too ( mimics Unix tee command ). +// +// Default logfile is named +// "llama..log" +// Default LOG_TEE() secondary output target is +// stderr +// +// Logs can be dynamically disabled or enabled using functions: +// log_disable() +// and +// log_enable() +// +// A log target can be changed with: +// log_set_target( string ) +// creating and opening, or re-opening a file by string filename +// or +// log_set_target( FILE* ) +// allowing to point at stderr, stdout, or any valid FILE* file handler. +// +// -------- +// +// End of Basic usage. +// +// -------------------------------- + +// Specifies a log target. 
+// default uses log_handler() with "llama.log" log file +// this can be changed, by defining LOG_TARGET +// like so: +// +// #define LOG_TARGET (a valid FILE*) +// #include "log.h" +// +// or it can be simply redirected to stdout or stderr +// like so: +// +// #define LOG_TARGET stderr +// #include "log.h" +// +// The log target can also be redirected to a different function +// like so: +// +// #define LOG_TARGET log_handler_different() +// #include "log.h" +// +// FILE* log_handler_different() +// { +// return stderr; +// } +// +// or: +// +// #define LOG_TARGET log_handler_another_one("somelog.log") +// #include "log.h" +// +// FILE* log_handler_another_one(char*filename) +// { +// static FILE* logfile = nullptr; +// (...) +// if( !logfile ) +// { +// fopen(...) +// } +// (...) +// return logfile +// } +// +#ifndef LOG_TARGET + #define LOG_TARGET log_handler() +#endif + +#ifndef LOG_TEE_TARGET + #define LOG_TEE_TARGET stderr +#endif + +// Utility for synchronizing log configuration state +// since std::optional was introduced only in c++17 +enum LogTriState +{ + LogTriStateSame, + LogTriStateFalse, + LogTriStateTrue +}; + +// Utility to obtain "pid" like unique process id and use it when creating log files. +inline std::string log_get_pid() +{ + static std::string pid; + if (pid.empty()) + { + // std::this_thread::get_id() is the most portable way of obtaining a "process id" + // it's not the same as "pid" but is unique enough to solve multiple instances + // trying to write to the same log. + std::stringstream ss; + ss << std::this_thread::get_id(); + pid = ss.str(); + } + + return pid; +} + +// Utility function for generating log file names with unique id based on thread id. +// invocation with log_filename_generator( "llama", "log" ) creates a string "llama..log" +// where the number is a runtime id of the current thread. + +#define log_filename_generator(log_file_basename, log_file_extension) log_filename_generator_impl(LogTriStateSame, log_file_basename, log_file_extension) + +// INTERNAL, DO NOT USE +inline std::string log_filename_generator_impl(LogTriState multilog, const std::string & log_file_basename, const std::string & log_file_extension) +{ + static bool _multilog = false; + + if (multilog != LogTriStateSame) + { + _multilog = multilog == LogTriStateTrue; + } + + std::stringstream buf; + + buf << log_file_basename; + if (_multilog) + { + buf << "."; + buf << log_get_pid(); + } + buf << "."; + buf << log_file_extension; + + return buf.str(); +} + +#ifndef LOG_DEFAULT_FILE_NAME + #define LOG_DEFAULT_FILE_NAME log_filename_generator("llama", "log") +#endif + +// Utility for turning #define values into string literals +// so we can have a define for stderr and +// we can print "stderr" instead of literal stderr, etc. +#define LOG_STRINGIZE1(s) #s +#define LOG_STRINGIZE(s) LOG_STRINGIZE1(s) + +#define LOG_TEE_TARGET_STRING LOG_STRINGIZE(LOG_TEE_TARGET) + +// Allows disabling timestamps. 
+// in order to disable, define LOG_NO_TIMESTAMPS +// like so: +// +// #define LOG_NO_TIMESTAMPS +// #include "log.h" +// +#ifndef LOG_NO_TIMESTAMPS + #ifndef _MSC_VER + #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " + #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() + #else + #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " + #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() + #endif +#else + #define LOG_TIMESTAMP_FMT "%s" + #define LOG_TIMESTAMP_VAL ,"" +#endif + +#ifdef LOG_TEE_TIMESTAMPS + #ifndef _MSC_VER + #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " + #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() + #else + #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " + #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() + #endif +#else + #define LOG_TEE_TIMESTAMP_FMT "%s" + #define LOG_TEE_TIMESTAMP_VAL ,"" +#endif + +// Allows disabling file/line/function prefix +// in order to disable, define LOG_NO_FILE_LINE_FUNCTION +// like so: +// +// #define LOG_NO_FILE_LINE_FUNCTION +// #include "log.h" +// +#ifndef LOG_NO_FILE_LINE_FUNCTION + #ifndef _MSC_VER + #define LOG_FLF_FMT "[%24s:%5d][%24s] " + #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ + #else + #define LOG_FLF_FMT "[%24s:%5ld][%24s] " + #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ + #endif +#else + #define LOG_FLF_FMT "%s" + #define LOG_FLF_VAL ,"" +#endif + +#ifdef LOG_TEE_FILE_LINE_FUNCTION + #ifndef _MSC_VER + #define LOG_TEE_FLF_FMT "[%24s:%5d][%24s] " + #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ + #else + #define LOG_TEE_FLF_FMT "[%24s:%5ld][%24s] " + #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ + #endif +#else + #define LOG_TEE_FLF_FMT "%s" + #define LOG_TEE_FLF_VAL ,"" +#endif + +// INTERNAL, DO NOT USE +// USE LOG() INSTEAD +// +#ifndef _MSC_VER + #define LOG_IMPL(str, ...) \ + do { \ + if (LOG_TARGET != nullptr) \ + { \ + fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ + fflush(LOG_TARGET); \ + } \ + } while (0) +#else + #define LOG_IMPL(str, ...) \ + do { \ + if (LOG_TARGET != nullptr) \ + { \ + fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ + fflush(LOG_TARGET); \ + } \ + } while (0) +#endif + +// INTERNAL, DO NOT USE +// USE LOG_TEE() INSTEAD +// +#ifndef _MSC_VER + #define LOG_TEE_IMPL(str, ...) \ + do { \ + if (LOG_TARGET != nullptr) \ + { \ + fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ + fflush(LOG_TARGET); \ + } \ + if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ + { \ + fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL, __VA_ARGS__); \ + fflush(LOG_TEE_TARGET); \ + } \ + } while (0) +#else + #define LOG_TEE_IMPL(str, ...) 
\ + do { \ + if (LOG_TARGET != nullptr) \ + { \ + fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ + fflush(LOG_TARGET); \ + } \ + if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ + { \ + fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL "", ##__VA_ARGS__); \ + fflush(LOG_TEE_TARGET); \ + } \ + } while (0) +#endif + +// The '\0' as a last argument, is a trick to bypass the silly +// "warning: ISO C++11 requires at least one argument for the "..." in a variadic macro" +// so we can have a single macro which can be called just like printf. + +// Main LOG macro. +// behaves like printf, and supports arguments the exact same way. +// +#ifndef _MSC_VER + #define LOG(...) LOG_IMPL(__VA_ARGS__, "") +#else + #define LOG(str, ...) LOG_IMPL("%s" str, "", __VA_ARGS__, "") +#endif + +// Main TEE macro. +// does the same as LOG +// and +// simultaneously writes stderr. +// +// Secondary target can be changed just like LOG_TARGET +// by defining LOG_TEE_TARGET +// +#ifndef _MSC_VER + #define LOG_TEE(...) LOG_TEE_IMPL(__VA_ARGS__, "") +#else + #define LOG_TEE(str, ...) LOG_TEE_IMPL("%s" str, "", __VA_ARGS__, "") +#endif + +// LOG macro variants with auto endline. +#ifndef _MSC_VER + #define LOGLN(...) LOG_IMPL(__VA_ARGS__, "\n") + #define LOG_TEELN(...) LOG_TEE_IMPL(__VA_ARGS__, "\n") +#else + #define LOGLN(str, ...) LOG_IMPL("%s" str, "", __VA_ARGS__, "\n") + #define LOG_TEELN(str, ...) LOG_TEE_IMPL("%s" str, "", __VA_ARGS__, "\n") +#endif + +// INTERNAL, DO NOT USE +inline FILE *log_handler1_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, const std::string & filename = LOG_DEFAULT_FILE_NAME, FILE *target = nullptr) +{ + static bool _initialized = false; + static bool _append = false; + static bool _disabled = filename.empty() && target == nullptr; + static std::string log_current_filename{filename}; + static FILE *log_current_target{target}; + static FILE *logfile = nullptr; + + if (change) + { + if (append != LogTriStateSame) + { + _append = append == LogTriStateTrue; + return logfile; + } + + if (disable == LogTriStateTrue) + { + // Disable primary target + _disabled = true; + } + // If previously disabled, only enable, and keep previous target + else if (disable == LogTriStateFalse) + { + _disabled = false; + } + // Otherwise, process the arguments + else if (log_current_filename != filename || log_current_target != target) + { + _initialized = false; + } + } + + if (_disabled) + { + // Log is disabled + return nullptr; + } + + if (_initialized) + { + // with fallback in case something went wrong + return logfile ? logfile : stderr; + } + + // do the (re)initialization + if (target != nullptr) + { + if (logfile != nullptr && logfile != stdout && logfile != stderr) + { + fclose(logfile); + } + + log_current_filename = LOG_DEFAULT_FILE_NAME; + log_current_target = target; + + logfile = target; + } + else + { + if (log_current_filename != filename) + { + if (logfile != nullptr && logfile != stdout && logfile != stderr) + { + fclose(logfile); + } + } + + logfile = fopen(filename.c_str(), _append ? 
"a" : "w"); + } + + if (!logfile) + { + // Verify whether the file was opened, otherwise fallback to stderr + logfile = stderr; + + fprintf(stderr, "Failed to open logfile '%s' with error '%s'\n", filename.c_str(), std::strerror(errno)); + fflush(stderr); + + // At this point we let the init flag be to true below, and let the target fallback to stderr + // otherwise we would repeatedly fopen() which was already unsuccessful + } + + _initialized = true; + + return logfile ? logfile : stderr; +} + +// INTERNAL, DO NOT USE +inline FILE *log_handler2_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, FILE *target = nullptr, const std::string & filename = LOG_DEFAULT_FILE_NAME) +{ + return log_handler1_impl(change, append, disable, filename, target); +} + +// Disables logs entirely at runtime. +// Makes LOG() and LOG_TEE() produce no output, +// until enabled back. +#define log_disable() log_disable_impl() + +// INTERNAL, DO NOT USE +inline FILE *log_disable_impl() +{ + return log_handler1_impl(true, LogTriStateSame, LogTriStateTrue); +} + +// Enables logs at runtime. +#define log_enable() log_enable_impl() + +// INTERNAL, DO NOT USE +inline FILE *log_enable_impl() +{ + return log_handler1_impl(true, LogTriStateSame, LogTriStateFalse); +} + +// Sets target fir logs, either by a file name or FILE* pointer (stdout, stderr, or any valid FILE*) +#define log_set_target(target) log_set_target_impl(target) + +// INTERNAL, DO NOT USE +inline FILE *log_set_target_impl(const std::string & filename) { return log_handler1_impl(true, LogTriStateSame, LogTriStateSame, filename); } +inline FILE *log_set_target_impl(FILE *target) { return log_handler2_impl(true, LogTriStateSame, LogTriStateSame, target); } + +// INTERNAL, DO NOT USE +inline FILE *log_handler() { return log_handler1_impl(); } + +// Enable or disable creating separate log files for each run. +// can ONLY be invoked BEFORE first log use. +#define log_multilog(enable) log_filename_generator_impl((enable) ? LogTriStateTrue : LogTriStateFalse, "", "") +// Enable or disable append mode for log file. +// can ONLY be invoked BEFORE first log use. +#define log_append(enable) log_append_impl(enable) +// INTERNAL, DO NOT USE +inline FILE *log_append_impl(bool enable) +{ + return log_handler1_impl(true, enable ? LogTriStateTrue : LogTriStateFalse, LogTriStateSame); +} + +inline void log_test() +{ + log_disable(); + LOG("01 Hello World to nobody, because logs are disabled!\n"); + log_enable(); + LOG("02 Hello World to default output, which is \"%s\" ( Yaaay, arguments! 
)!\n", LOG_STRINGIZE(LOG_TARGET)); + LOG_TEE("03 Hello World to **both** default output and " LOG_TEE_TARGET_STRING "!\n"); + log_set_target(stderr); + LOG("04 Hello World to stderr!\n"); + LOG_TEE("05 Hello World TEE with double printing to stderr prevented!\n"); + log_set_target(LOG_DEFAULT_FILE_NAME); + LOG("06 Hello World to default log file!\n"); + log_set_target(stdout); + LOG("07 Hello World to stdout!\n"); + log_set_target(LOG_DEFAULT_FILE_NAME); + LOG("08 Hello World to default log file again!\n"); + log_disable(); + LOG("09 Hello World _1_ into the void!\n"); + log_enable(); + LOG("10 Hello World back from the void ( you should not see _1_ in the log or the output )!\n"); + log_disable(); + log_set_target("llama.anotherlog.log"); + LOG("11 Hello World _2_ to nobody, new target was selected but logs are still disabled!\n"); + log_enable(); + LOG("12 Hello World this time in a new file ( you should not see _2_ in the log or the output )?\n"); + log_set_target("llama.yetanotherlog.log"); + LOG("13 Hello World this time in yet new file?\n"); + log_set_target(log_filename_generator("llama_autonamed", "log")); + LOG("14 Hello World in log with generated filename!\n"); +#ifdef _MSC_VER + LOG_TEE("15 Hello msvc TEE without arguments\n"); + LOG_TEE("16 Hello msvc TEE with (%d)(%s) arguments\n", 1, "test"); + LOG_TEELN("17 Hello msvc TEELN without arguments\n"); + LOG_TEELN("18 Hello msvc TEELN with (%d)(%s) arguments\n", 1, "test"); + LOG("19 Hello msvc LOG without arguments\n"); + LOG("20 Hello msvc LOG with (%d)(%s) arguments\n", 1, "test"); + LOGLN("21 Hello msvc LOGLN without arguments\n"); + LOGLN("22 Hello msvc LOGLN with (%d)(%s) arguments\n", 1, "test"); +#endif +} + +inline bool log_param_single_parse(const std::string & param) +{ + if ( param == "--log-test") + { + log_test(); + return true; + } + + if ( param == "--log-disable") + { + log_disable(); + return true; + } + + if ( param == "--log-enable") + { + log_enable(); + return true; + } + + if (param == "--log-new") + { + log_multilog(true); + return true; + } + + if (param == "--log-append") + { + log_append(true); + return true; + } + + return false; +} + +inline bool log_param_pair_parse(bool check_but_dont_parse, const std::string & param, const std::string & next = std::string()) +{ + if ( param == "--log-file") + { + if (!check_but_dont_parse) + { + log_set_target(log_filename_generator(next.empty() ? "unnamed" : next, "log")); + } + + return true; + } + + return false; +} + +inline void log_print_usage() +{ + printf("log options:\n"); + /* format + printf(" -h, --help show this help message and exit\n");*/ + /* spacing + printf("__-param----------------Description\n");*/ + printf(" --log-test Run simple logging test\n"); + printf(" --log-disable Disable trace logs\n"); + printf(" --log-enable Enable trace logs\n"); + printf(" --log-file Specify a log filename (without extension)\n"); + printf(" --log-new Create a separate new log file on start. 
" + "Each log file will have unique name: \"..log\"\n"); + printf(" --log-append Don't truncate the old log file.\n"); +} + +#define log_dump_cmdline(argc, argv) log_dump_cmdline_impl(argc, argv) + +// INTERNAL, DO NOT USE +inline void log_dump_cmdline_impl(int argc, char **argv) +{ + std::stringstream buf; + for (int i = 0; i < argc; ++i) + { + if (std::string(argv[i]).find(' ') != std::string::npos) + { + buf << " \"" << argv[i] <<"\""; + } + else + { + buf << " " << argv[i]; + } + } + LOGLN("Cmd:%s", buf.str().c_str()); +} + +#define log_tostr(var) log_var_to_string_impl(var).c_str() + +inline std::string log_var_to_string_impl(bool var) +{ + return var ? "true" : "false"; +} + +inline std::string log_var_to_string_impl(std::string var) +{ + return var; +} + +inline std::string log_var_to_string_impl(const std::vector & var) +{ + std::stringstream buf; + buf << "[ "; + bool first = true; + for (auto e : var) + { + if (first) + { + first = false; + } + else + { + buf << ", "; + } + buf << std::to_string(e); + } + buf << " ]"; + + return buf.str(); +} + +template +inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens) +{ + std::stringstream buf; + buf << "[ "; + + bool first = true; + for (const auto &token : tokens) + { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = llama_token_to_piece(ctx, token); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf + << "'" << detokenized << "'" + << ":" << std::to_string(token); + } + buf << " ]"; + + return buf.str(); +} + +template +inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch) +{ + std::stringstream buf; + buf << "[ "; + + bool first = true; + for (int i = 0; i < batch.n_tokens; ++i) + { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = llama_token_to_piece(ctx, batch.token[i]); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf + << "\n" << std::to_string(i) + << ":token '" << detokenized << "'" + << ":pos " << std::to_string(batch.pos[i]) + << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) + << ":seq_id " << std::to_string(batch.seq_id[i][0]) + << ":logits " << std::to_string(batch.logits[i]); + } + buf << " ]"; + + return buf.str(); +} + +#ifdef LOG_DISABLE_LOGS + +#undef LOG +#define LOG(...) // dummy stub +#undef LOGLN +#define LOGLN(...) // dummy stub + +#undef LOG_TEE +#define LOG_TEE(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf + +#undef LOG_TEELN +#define LOG_TEELN(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf + +#undef LOG_DISABLE +#define LOG_DISABLE() // dummy stub + +#undef LOG_ENABLE +#define LOG_ENABLE() // dummy stub + +#undef LOG_ENABLE +#define LOG_ENABLE() // dummy stub + +#undef LOG_SET_TARGET +#define LOG_SET_TARGET(...) // dummy stub + +#undef LOG_DUMP_CMDLINE +#define LOG_DUMP_CMDLINE(...) 
// dummy stub + +#endif // LOG_DISABLE_LOGS diff --git a/applications/src/llama/common/sampling.cpp b/applications/src/llama/common/sampling.cpp new file mode 100644 index 000000000..f4e76df31 --- /dev/null +++ b/applications/src/llama/common/sampling.cpp @@ -0,0 +1,269 @@ +#include "sampling.h" + +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { + struct llama_sampling_context * result = new llama_sampling_context(); + + result->params = params; + result->grammar = nullptr; + + // if there is a grammar, parse it + if (!params.grammar.empty()) { + result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + + // will be empty (default) if there are parse errors + if (result->parsed_grammar.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); + return nullptr; + } + + std::vector grammar_rules(result->parsed_grammar.c_rules()); + + result->grammar = llama_grammar_init( + grammar_rules.data(), + grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); + } + + result->prev.resize(params.n_prev); + + return result; +} + +void llama_sampling_free(struct llama_sampling_context * ctx) { + if (ctx->grammar != NULL) { + llama_grammar_free(ctx->grammar); + } + + delete ctx; +} + +void llama_sampling_reset(llama_sampling_context * ctx) { + if (ctx->grammar != NULL) { + llama_grammar_free(ctx->grammar); + ctx->grammar = NULL; + } + + if (!ctx->parsed_grammar.rules.empty()) { + std::vector grammar_rules(ctx->parsed_grammar.c_rules()); + + ctx->grammar = llama_grammar_init( + grammar_rules.data(), + grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); + } + + std::fill(ctx->prev.begin(), ctx->prev.end(), 0); + ctx->cur.clear(); +} + +void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { + if (dst->grammar) { + llama_grammar_free(dst->grammar); + dst->grammar = nullptr; + } + + if (src->grammar) { + dst->grammar = llama_grammar_copy(src->grammar); + } + + dst->prev = src->prev; +} + +llama_token llama_sampling_last(llama_sampling_context * ctx) { + return ctx->prev.back(); +} + +std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) { + const int size = ctx_sampling->prev.size(); + + n = std::min(n, size); + + std::string result; + + for (int i = size - n; i < size; i++) { + result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]); + } + + return result; +} + +std::string llama_sampling_print(const llama_sampling_params & params) { + char result[1024]; + + snprintf(result, sizeof(result), + "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" + "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", + params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, + params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, + params.mirostat, params.mirostat_eta, params.mirostat_tau); + + return std::string(result); +} + +std::string llama_sampling_order_print(const llama_sampling_params & params) { + std::string result = "CFG -> Penalties "; + if (params.mirostat == 0) { + for (auto s : params.samplers_sequence) { + switch (s) { + case 'k': result += "-> top_k "; break; + case 'f': result += "-> tfs_z "; break; + case 'y': result += "-> typical_p "; break; + case 'p': result += "-> top_p "; break; + case 'm': 
result += "-> min_p "; break; + case 't': result += "-> temp "; break; + default : break; + } + } + } else { + result += "-> mirostat "; + } + + return result; +} + +// no reasons to expose this function in header +static void sampler_queue( + struct llama_context * ctx_main, + const llama_sampling_params & params, + llama_token_data_array & cur_p, + size_t & min_keep) { + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + + const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; + const float top_p = params.top_p; + const float min_p = params.min_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const std::string & samplers_sequence = params.samplers_sequence; + + for (auto s : samplers_sequence) { + switch (s){ + case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; + case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; + case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; + case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; + case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; + case 't': llama_sample_temp (ctx_main, &cur_p, temp); break; + default : break; + } + } +} + +llama_token llama_sampling_sample( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + const llama_sampling_params & params = ctx_sampling->params; + + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + + const float temp = params.temp; + const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; + const float penalty_repeat = params.penalty_repeat; + const float penalty_freq = params.penalty_freq; + const float penalty_present = params.penalty_present; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; + const bool penalize_nl = params.penalize_nl; + + auto & prev = ctx_sampling->prev; + auto & cur = ctx_sampling->cur; + + llama_token id = 0; + + float * logits = llama_get_logits_ith(ctx_main, idx); + + // apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + logits[it->first] += it->second; + } + + cur.clear(); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + + if (ctx_cfg) { + llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale); + } + + // apply penalties + if (!prev.empty()) { + const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; + + llama_sample_repetition_penalties(ctx_main, &cur_p, + prev.data() + prev.size() - penalty_last_n, + penalty_last_n, penalty_repeat, penalty_freq, penalty_present); + + if (!penalize_nl) { + for (size_t idx = 0; idx < cur_p.size; idx++) { + if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { + cur_p.data[idx].logit = nl_logit; + break; + } + } + } + } + + if (ctx_sampling->grammar != NULL) { + llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); + } + + if (temp < 0.0) { + // greedy sampling, with probs + llama_sample_softmax(ctx_main, &cur_p); + id = cur_p.data[0].id; + } else if (temp == 0.0) { + // greedy sampling, no probs + id = 
llama_sample_token_greedy(ctx_main, &cur_p); + } else { + if (mirostat == 1) { + const int mirostat_m = 100; + llama_sample_temp(ctx_main, &cur_p, temp); + id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); + } else if (mirostat == 2) { + llama_sample_temp(ctx_main, &cur_p, temp); + id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); + } else { + // temperature sampling + size_t min_keep = std::max(1, params.n_probs); + + sampler_queue(ctx_main, params, cur_p, min_keep); + + id = llama_sample_token(ctx_main, &cur_p); + + //{ + // const int n_top = 10; + // LOG("top %d candidates:\n", n_top); + + // for (int i = 0; i < n_top; i++) { + // const llama_token id = cur_p.data[i].id; + // (void)id; // To avoid a warning that id is unused when logging is disabled. + // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p); + // } + //} + + LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str()); + } + } + + return id; +} + +void llama_sampling_accept( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + llama_token id, + bool apply_grammar) { + ctx_sampling->prev.erase(ctx_sampling->prev.begin()); + ctx_sampling->prev.push_back(id); + + if (ctx_sampling->grammar != NULL && apply_grammar) { + llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); + } +} diff --git a/applications/src/llama/common/sampling.h b/applications/src/llama/common/sampling.h new file mode 100644 index 000000000..fdfa9eed1 --- /dev/null +++ b/applications/src/llama/common/sampling.h @@ -0,0 +1,114 @@ +#pragma once + +#include "llama.h" + +#include "grammar-parser.h" + +#include +#include +#include + +// sampling parameters +typedef struct llama_sampling_params { + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 
+ int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // 1.0 = disabled + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.10f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = true; // consider newlines as a repeatable token + std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp + + std::string grammar; // optional BNF-like grammar to constrain sampling + + // Classifier-Free Guidance + // https://arxiv.org/abs/2306.17806 + std::string cfg_negative_prompt; // string to help guidance + float cfg_scale = 1.f; // how strong is guidance + + std::unordered_map logit_bias; // logit bias for specific tokens +} llama_sampling_params; + +// general sampler context +// TODO: move to llama.h +struct llama_sampling_context { + // parameters that will be used for sampling + llama_sampling_params params; + + // mirostat sampler state + float mirostat_mu; + + llama_grammar * grammar; + + // internal + grammar_parser::parse_state parsed_grammar; + + // TODO: replace with ring-buffer + std::vector prev; + std::vector cur; +}; + +#include "common.h" + +// Create a new sampling context instance. +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); + +void llama_sampling_free(struct llama_sampling_context * ctx); + +// Reset the sampler context +// - clear prev tokens +// - reset grammar +void llama_sampling_reset(llama_sampling_context * ctx); + +// Copy the sampler context +void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); + +// Get the last sampled token +llama_token llama_sampling_last(llama_sampling_context * ctx); + +// Get a string representation of the last sampled tokens +std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n); + +// Print sampling parameters into a string +std::string llama_sampling_print(const llama_sampling_params & params); + +// Print sampling order into a string +std::string llama_sampling_order_print(const llama_sampling_params & params); + +// this is a common sampling function used across the examples for convenience +// it can serve as a starting point for implementing your own sampling function +// Note: When using multiple sequences, it is the caller's responsibility to call +// llama_sampling_reset when a sequence ends +// +// required: +// - ctx_main: context to use for sampling +// - ctx_sampling: sampling-specific context +// +// optional: +// - ctx_cfg: context to use for classifier-free guidance +// - idx: sample from llama_get_logits_ith(ctx, idx) +// +// returns: +// - token: sampled token +// - candidates: vector of candidate tokens +// +llama_token llama_sampling_sample( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx = 0); + +void llama_sampling_accept( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + llama_token id, + bool apply_grammar); diff --git 
a/applications/src/llama/common/stb_image.h b/applications/src/llama/common/stb_image.h new file mode 100644 index 000000000..4766d7e67 --- /dev/null +++ b/applications/src/llama/common/stb_image.h @@ -0,0 +1,8396 @@ +/* stb_image - v2.28 - public domain image loader - http://nothings.org/stb + no warranty implied; use at your own risk + + Do this: + #define STB_IMAGE_IMPLEMENTATION + before you include this file in *one* C or C++ file to create the implementation. + + // i.e. it should look like this: + #include ... + #include ... + #include ... + #define STB_IMAGE_IMPLEMENTATION + #include "stb_image.h" + + You can #define STBI_ASSERT(x) before the #include to avoid using assert.h. + And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using malloc,realloc,free + + + QUICK NOTES: + Primarily of interest to game developers and other people who can + avoid problematic images and only need the trivial interface + + JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib) + PNG 1/2/4/8/16-bit-per-channel + + TGA (not sure what subset, if a subset) + BMP non-1bpp, non-RLE + PSD (composited view only, no extra channels, 8/16 bit-per-channel) + + GIF (*comp always reports as 4-channel) + HDR (radiance rgbE format) + PIC (Softimage PIC) + PNM (PPM and PGM binary only) + + Animated GIF still needs a proper API, but here's one way to do it: + http://gist.github.com/urraka/685d9a6340b26b830d49 + + - decode from memory or through FILE (define STBI_NO_STDIO to remove code) + - decode from arbitrary I/O callbacks + - SIMD acceleration on x86/x64 (SSE2) and ARM (NEON) + + Full documentation under "DOCUMENTATION" below. + + +LICENSE + + See end of file for license information. + +RECENT REVISION HISTORY: + + 2.28 (2023-01-29) many error fixes, security errors, just tons of stuff + 2.27 (2021-07-11) document stbi_info better, 16-bit PNM support, bug fixes + 2.26 (2020-07-13) many minor fixes + 2.25 (2020-02-02) fix warnings + 2.24 (2020-02-02) fix warnings; thread-local failure_reason and flip_vertically + 2.23 (2019-08-11) fix clang static analysis warning + 2.22 (2019-03-04) gif fixes, fix warnings + 2.21 (2019-02-25) fix typo in comment + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.19 (2018-02-11) fix warning + 2.18 (2018-01-30) fix warnings + 2.17 (2018-01-29) bugfix, 1-bit BMP, 16-bitness query, fix warnings + 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; bugfixes + 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE detection on GCC + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs + 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; fixes + 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes + 2.11 (2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 + RGB-format JPEG; remove white matting in PSD; + allocate large structures on the stack; + correct channel count for PNG & BMP + 2.10 (2016-01-22) avoid warning introduced in 2.09 + 2.09 (2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED + + See end of file for full revision history. 
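The QUICK NOTES above mention that STBI_ASSERT and the STBI_MALLOC/STBI_REALLOC/STBI_FREE hooks can be overridden before the include, but only name the macros. A minimal sketch of what that could look like in the one translation unit that builds the implementation; my_alloc/my_realloc/my_free are hypothetical placeholders (not part of stb_image), and the header requires defining all three allocator macros or none of them:

    #include <stdlib.h>

    /* Placeholder hooks, illustrative only: forward to the C library here,
       but this is where a custom arena or tracking allocator would plug in. */
    static void *my_alloc(size_t sz)               { return malloc(sz); }
    static void *my_realloc(void *p, size_t newsz) { return realloc(p, newsz); }
    static void  my_free(void *p)                  { free(p); }

    #define STBI_ASSERT(x)         ((void)0)        /* e.g. compile the asserts out */
    #define STBI_MALLOC(sz)        my_alloc(sz)
    #define STBI_REALLOC(p, newsz) my_realloc(p, newsz)
    #define STBI_FREE(p)           my_free(p)

    #define STB_IMAGE_IMPLEMENTATION
    #include "stb_image.h"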
+ + + ============================ Contributors ========================= + + Image formats Extensions, features + Sean Barrett (jpeg, png, bmp) Jetro Lauha (stbi_info) + Nicolas Schulz (hdr, psd) Martin "SpartanJ" Golini (stbi_info) + Jonathan Dummer (tga) James "moose2000" Brown (iPhone PNG) + Jean-Marc Lienher (gif) Ben "Disch" Wenger (io callbacks) + Tom Seddon (pic) Omar Cornut (1/2/4-bit PNG) + Thatcher Ulrich (psd) Nicolas Guillemot (vertical flip) + Ken Miller (pgm, ppm) Richard Mitton (16-bit PSD) + github:urraka (animated gif) Junggon Kim (PNM comments) + Christopher Forseth (animated gif) Daniel Gibson (16-bit TGA) + socks-the-fox (16-bit PNG) + Jeremy Sawicki (handle all ImageNet JPGs) + Optimizations & bugfixes Mikhail Morozov (1-bit BMP) + Fabian "ryg" Giesen Anael Seghezzi (is-16-bit query) + Arseny Kapoulkine Simon Breuss (16-bit PNM) + John-Mark Allen + Carmelo J Fdez-Aguera + + Bug & warning fixes + Marc LeBlanc David Woo Guillaume George Martins Mozeiko + Christpher Lloyd Jerry Jansson Joseph Thomson Blazej Dariusz Roszkowski + Phil Jordan Dave Moore Roy Eltham + Hayaki Saito Nathan Reed Won Chun + Luke Graham Johan Duparc Nick Verigakis the Horde3D community + Thomas Ruf Ronny Chevalier github:rlyeh + Janez Zemva John Bartholomew Michal Cichon github:romigrou + Jonathan Blow Ken Hamada Tero Hanninen github:svdijk + Eugene Golushkov Laurent Gomila Cort Stratton github:snagar + Aruelien Pocheville Sergio Gonzalez Thibault Reuille github:Zelex + Cass Everitt Ryamond Barbiero github:grim210 + Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw + Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus + Josh Tobin Neil Bickford Matthew Gregan github:poppolopoppo + Julian Raschke Gregory Mullen Christian Floisand github:darealshinji + Baldur Karlsson Kevin Schmidt JR Smith github:Michaelangel007 + Brad Weinberger Matvey Cherevko github:mosra + Luca Sas Alexander Veselov Zack Middleton [reserved] + Ryan C. Gordon [reserved] [reserved] + DO NOT ADD YOUR NAME HERE + + Jacko Dirks + + To add your name to the credits, pick a random blank space in the middle and fill it. + 80% of merge conflicts on stb PRs are due to people adding their name at the end + of the credits. +*/ + +#ifndef STBI_INCLUDE_STB_IMAGE_H +#define STBI_INCLUDE_STB_IMAGE_H + +// DOCUMENTATION +// +// Limitations: +// - no 12-bit-per-channel JPEG +// - no JPEGs with arithmetic coding +// - GIF always returns *comp=4 +// +// Basic usage (see HDR discussion below for HDR usage): +// int x,y,n; +// unsigned char *data = stbi_load(filename, &x, &y, &n, 0); +// // ... process data if not NULL ... +// // ... x = width, y = height, n = # 8-bit components per pixel ... +// // ... replace '0' with '1'..'4' to force that many components per pixel +// // ... but 'n' will always be the number that it would have been if you said 0 +// stbi_image_free(data); +// +// Standard parameters: +// int *x -- outputs image width in pixels +// int *y -- outputs image height in pixels +// int *channels_in_file -- outputs # of image components in image file +// int desired_channels -- if non-zero, # of image components requested in result +// +// The return value from an image loader is an 'unsigned char *' which points +// to the pixel data, or NULL on an allocation failure or if the image is +// corrupt or invalid. The pixel data consists of *y scanlines of *x pixels, +// with each pixel consisting of N interleaved 8-bit components; the first +// pixel pointed to is top-left-most in the image. 
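Because the data comes back as *y rows of *x pixels with the top-left pixel first and N interleaved 8-bit components per pixel, addressing a single pixel is plain index arithmetic. A small sketch, with "input.png" and the coordinates as placeholder values, forcing 4 components so every pixel is RGBA:

    #include <stdio.h>
    #include "stb_image.h"

    int main(void) {
        int w, h, n;
        /* force desired_channels = 4 so each pixel is exactly 4 bytes (RGBA) */
        unsigned char *data = stbi_load("input.png", &w, &h, &n, 4);
        if (!data) {
            fprintf(stderr, "load failed: %s\n", stbi_failure_reason());
            return 1;
        }
        int px = 10, py = 20;                               /* placeholder coordinates */
        if (px < w && py < h) {
            unsigned char *p = data + ((size_t)py * w + px) * 4; /* row-major, contiguous rows */
            printf("RGBA at (%d,%d) = %d %d %d %d\n", px, py, p[0], p[1], p[2], p[3]);
        }
        stbi_image_free(data);
        return 0;
    }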
There is no padding between +// image scanlines or between pixels, regardless of format. The number of +// components N is 'desired_channels' if desired_channels is non-zero, or +// *channels_in_file otherwise. If desired_channels is non-zero, +// *channels_in_file has the number of components that _would_ have been +// output otherwise. E.g. if you set desired_channels to 4, you will always +// get RGBA output, but you can check *channels_in_file to see if it's trivially +// opaque because e.g. there were only 3 channels in the source image. +// +// An output image with N components has the following components interleaved +// in this order in each pixel: +// +// N=#comp components +// 1 grey +// 2 grey, alpha +// 3 red, green, blue +// 4 red, green, blue, alpha +// +// If image loading fails for any reason, the return value will be NULL, +// and *x, *y, *channels_in_file will be unchanged. The function +// stbi_failure_reason() can be queried for an extremely brief, end-user +// unfriendly explanation of why the load failed. Define STBI_NO_FAILURE_STRINGS +// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get slightly +// more user-friendly ones. +// +// Paletted PNG, BMP, GIF, and PIC images are automatically depalettized. +// +// To query the width, height and component count of an image without having to +// decode the full file, you can use the stbi_info family of functions: +// +// int x,y,n,ok; +// ok = stbi_info(filename, &x, &y, &n); +// // returns ok=1 and sets x, y, n if image is a supported format, +// // 0 otherwise. +// +// Note that stb_image pervasively uses ints in its public API for sizes, +// including sizes of memory buffers. This is now part of the API and thus +// hard to change without causing breakage. As a result, the various image +// loaders all have certain limits on image size; these differ somewhat +// by format but generally boil down to either just under 2GB or just under +// 1GB. When the decoded image would be larger than this, stb_image decoding +// will fail. +// +// Additionally, stb_image will reject image files that have any of their +// dimensions set to a larger value than the configurable STBI_MAX_DIMENSIONS, +// which defaults to 2**24 = 16777216 pixels. Due to the above memory limit, +// the only way to have an image with such dimensions load correctly +// is for it to have a rather extreme aspect ratio. Either way, the +// assumption here is that such larger images are likely to be malformed +// or malicious. If you do need to load an image with individual dimensions +// larger than that, and it still fits in the overall size limit, you can +// #define STBI_MAX_DIMENSIONS on your own to be something larger. +// +// =========================================================================== +// +// UNICODE: +// +// If compiling for Windows and you wish to use Unicode filenames, compile +// with +// #define STBI_WINDOWS_UTF8 +// and pass utf8-encoded filenames. Call stbi_convert_wchar_to_utf8 to convert +// Windows wchar_t filenames to utf8. +// +// =========================================================================== +// +// Philosophy +// +// stb libraries are designed with the following priorities: +// +// 1. easy to use +// 2. easy to maintain +// 3. good performance +// +// Sometimes I let "good performance" creep up in priority over "easy to maintain", +// and for best performance I may provide less-easy-to-use APIs that give higher +// performance, in addition to the easy-to-use ones. 
Nevertheless, it's important +// to keep in mind that from the standpoint of you, a client of this library, +// all you care about is #1 and #3, and stb libraries DO NOT emphasize #3 above all. +// +// Some secondary priorities arise directly from the first two, some of which +// provide more explicit reasons why performance can't be emphasized. +// +// - Portable ("ease of use") +// - Small source code footprint ("easy to maintain") +// - No dependencies ("ease of use") +// +// =========================================================================== +// +// I/O callbacks +// +// I/O callbacks allow you to read from arbitrary sources, like packaged +// files or some other source. Data read from callbacks are processed +// through a small internal buffer (currently 128 bytes) to try to reduce +// overhead. +// +// The three functions you must define are "read" (reads some bytes of data), +// "skip" (skips some bytes of data), "eof" (reports if the stream is at the end). +// +// =========================================================================== +// +// SIMD support +// +// The JPEG decoder will try to automatically use SIMD kernels on x86 when +// supported by the compiler. For ARM Neon support, you must explicitly +// request it. +// +// (The old do-it-yourself SIMD API is no longer supported in the current +// code.) +// +// On x86, SSE2 will automatically be used when available based on a run-time +// test; if not, the generic C versions are used as a fall-back. On ARM targets, +// the typical path is to have separate builds for NEON and non-NEON devices +// (at least this is true for iOS and Android). Therefore, the NEON support is +// toggled by a build flag: define STBI_NEON to get NEON loops. +// +// If for some reason you do not want to use any of SIMD code, or if +// you have issues compiling it, you can disable it entirely by +// defining STBI_NO_SIMD. +// +// =========================================================================== +// +// HDR image support (disable by defining STBI_NO_HDR) +// +// stb_image supports loading HDR images in general, and currently the Radiance +// .HDR file format specifically. You can still load any file through the existing +// interface; if you attempt to load an HDR file, it will be automatically remapped +// to LDR, assuming gamma 2.2 and an arbitrary scale factor defaulting to 1; +// both of these constants can be reconfigured through this interface: +// +// stbi_hdr_to_ldr_gamma(2.2f); +// stbi_hdr_to_ldr_scale(1.0f); +// +// (note, do not use _inverse_ constants; stbi_image will invert them +// appropriately). 
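Putting those knobs together, a brief sketch of a hypothetical helper (load_as_ldr is not part of stb_image): check whether the file is HDR, set the remap constants to the defaults quoted above, then load through the ordinary 8-bit interface so the image is remapped to LDR on the way in:

    #include "stb_image.h"

    /* Sketch: load any image as 8-bit, remapping HDR inputs with the default constants. */
    unsigned char *load_as_ldr(const char *path, int *x, int *y, int *n) {
        if (stbi_is_hdr(path)) {
            stbi_hdr_to_ldr_gamma(2.2f);   /* defaults quoted in the text above */
            stbi_hdr_to_ldr_scale(1.0f);
        }
        return stbi_load(path, x, y, n, 0);  /* caller frees with stbi_image_free() */
    }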
+// +// Additionally, there is a new, parallel interface for loading files as +// (linear) floats to preserve the full dynamic range: +// +// float *data = stbi_loadf(filename, &x, &y, &n, 0); +// +// If you load LDR images through this interface, those images will +// be promoted to floating point values, run through the inverse of +// constants corresponding to the above: +// +// stbi_ldr_to_hdr_scale(1.0f); +// stbi_ldr_to_hdr_gamma(2.2f); +// +// Finally, given a filename (or an open file or memory block--see header +// file for details) containing image data, you can query for the "most +// appropriate" interface to use (that is, whether the image is HDR or +// not), using: +// +// stbi_is_hdr(char *filename); +// +// =========================================================================== +// +// iPhone PNG support: +// +// We optionally support converting iPhone-formatted PNGs (which store +// premultiplied BGRA) back to RGB, even though they're internally encoded +// differently. To enable this conversion, call +// stbi_convert_iphone_png_to_rgb(1). +// +// Call stbi_set_unpremultiply_on_load(1) as well to force a divide per +// pixel to remove any premultiplied alpha *only* if the image file explicitly +// says there's premultiplied data (currently only happens in iPhone images, +// and only if iPhone convert-to-rgb processing is on). +// +// =========================================================================== +// +// ADDITIONAL CONFIGURATION +// +// - You can suppress implementation of any of the decoders to reduce +// your code footprint by #defining one or more of the following +// symbols before creating the implementation. +// +// STBI_NO_JPEG +// STBI_NO_PNG +// STBI_NO_BMP +// STBI_NO_PSD +// STBI_NO_TGA +// STBI_NO_GIF +// STBI_NO_HDR +// STBI_NO_PIC +// STBI_NO_PNM (.ppm and .pgm) +// +// - You can request *only* certain decoders and suppress all other ones +// (this will be more forward-compatible, as addition of new decoders +// doesn't require you to disable them explicitly): +// +// STBI_ONLY_JPEG +// STBI_ONLY_PNG +// STBI_ONLY_BMP +// STBI_ONLY_PSD +// STBI_ONLY_TGA +// STBI_ONLY_GIF +// STBI_ONLY_HDR +// STBI_ONLY_PIC +// STBI_ONLY_PNM (.ppm and .pgm) +// +// - If you use STBI_NO_PNG (or _ONLY_ without PNG), and you still +// want the zlib decoder to be available, #define STBI_SUPPORT_ZLIB +// +// - If you define STBI_MAX_DIMENSIONS, stb_image will reject images greater +// than that size (in either width or height) without further processing. +// This is to let programs in the wild set an upper bound to prevent +// denial-of-service attacks on untrusted data, as one could generate a +// valid image of gigantic dimensions and force stb_image to allocate a +// huge block of memory and spend disproportionate time decoding it. By +// default this is set to (1 << 24), which is 16777216, but that's still +// very big. 
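As a concrete sketch of the configuration defines listed above (the macro names are the documented ones; the 8192-pixel cap is only an illustrative value), a PNG-only build with a tighter dimension limit would set everything up in the single implementation translation unit:

    /* PNG decoding only; all other decoders are compiled out. */
    #define STBI_ONLY_PNG
    /* Reject images wider or taller than 8192 pixels (default is 1 << 24). */
    #define STBI_MAX_DIMENSIONS 8192
    #define STB_IMAGE_IMPLEMENTATION
    #include "stb_image.h"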
+ +#ifndef STBI_NO_STDIO +#include +#endif // STBI_NO_STDIO + +#define STBI_VERSION 1 + +enum { + STBI_default = 0, // only used for desired_channels + + STBI_grey = 1, + STBI_grey_alpha = 2, + STBI_rgb = 3, + STBI_rgb_alpha = 4 +}; + +#include +typedef unsigned char stbi_uc; +typedef unsigned short stbi_us; + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef STBIDEF +#ifdef STB_IMAGE_STATIC +#define STBIDEF static +#else +#define STBIDEF extern +#endif +#endif + +////////////////////////////////////////////////////////////////////////////// +// +// PRIMARY API - works on images of any type +// + +// +// load image by filename, open file, or memory buffer +// + +typedef struct { + int (*read)(void * user, char * data, + int size); // fill 'data' with 'size' bytes. return number of bytes actually read + void (*skip)(void * user, int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative + int (*eof)(void * user); // returns nonzero if we are at end of file/data +} stbi_io_callbacks; + +//////////////////////////////////// +// +// 8-bits-per-channel interface +// + +STBIDEF stbi_uc * stbi_load_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, + int desired_channels); +STBIDEF stbi_uc * stbi_load_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, + int * channels_in_file, int desired_channels); + +#ifndef STBI_NO_STDIO +STBIDEF stbi_uc * stbi_load(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF stbi_uc * stbi_load_from_file(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +// for stbi_load_from_file, file pointer is left pointing immediately after image +#endif + +#ifndef STBI_NO_GIF +STBIDEF stbi_uc * stbi_load_gif_from_memory(stbi_uc const * buffer, int len, int ** delays, int * x, int * y, int * z, + int * comp, int req_comp); +#endif + +#ifdef STBI_WINDOWS_UTF8 +STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wchar_t * input); +#endif + +//////////////////////////////////// +// +// 16-bits-per-channel interface +// + +STBIDEF stbi_us * stbi_load_16_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, + int desired_channels); +STBIDEF stbi_us * stbi_load_16_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, + int * channels_in_file, int desired_channels); + +#ifndef STBI_NO_STDIO +STBIDEF stbi_us * stbi_load_16(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF stbi_us * stbi_load_from_file_16(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +#endif + +//////////////////////////////////// +// +// float-per-channel interface +// +#ifndef STBI_NO_LINEAR +STBIDEF float * stbi_loadf_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, + int desired_channels); +STBIDEF float * stbi_loadf_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * channels_in_file, + int desired_channels); + +#ifndef STBI_NO_STDIO +STBIDEF float * stbi_loadf(char const * filename, int * x, int * y, int * channels_in_file, int desired_channels); +STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * channels_in_file, int desired_channels); +#endif +#endif + +#ifndef STBI_NO_HDR +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma); +STBIDEF void stbi_hdr_to_ldr_scale(float scale); +#endif // STBI_NO_HDR + +#ifndef 
STBI_NO_LINEAR +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma); +STBIDEF void stbi_ldr_to_hdr_scale(float scale); +#endif // STBI_NO_LINEAR + +// stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const * clbk, void * user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const * buffer, int len); +#ifndef STBI_NO_STDIO +STBIDEF int stbi_is_hdr(char const * filename); +STBIDEF int stbi_is_hdr_from_file(FILE * f); +#endif // STBI_NO_STDIO + +// get a VERY brief reason for failure +// on most compilers (and ALL modern mainstream compilers) this is threadsafe +STBIDEF const char * stbi_failure_reason(void); + +// free the loaded image -- this is just free() +STBIDEF void stbi_image_free(void * retval_from_stbi_load); + +// get image dimensions & components without fully decoding +STBIDEF int stbi_info_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const * buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * clbk, void * user); + +#ifndef STBI_NO_STDIO +STBIDEF int stbi_info(char const * filename, int * x, int * y, int * comp); +STBIDEF int stbi_info_from_file(FILE * f, int * x, int * y, int * comp); +STBIDEF int stbi_is_16_bit(char const * filename); +STBIDEF int stbi_is_16_bit_from_file(FILE * f); +#endif + +// for image formats that explicitly notate that they have premultiplied alpha, +// we just return the colors as stored in the file. set this flag to force +// unpremultiplication. results are undefined if the unpremultiply overflow. +STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); + +// indicate whether we should process iphone images back to canonical format, +// or just pass them through "as-is" +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert); + +// flip the image vertically, so the first pixel in the output array is the bottom left +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip); + +// as above, but only applies to images loaded on the thread that calls the function +// this function is only available if your compiler supports thread-local variables; +// calling it will fail to link if your compiler doesn't +STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply); +STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert); +STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); + +// ZLIB client - used by PNG, available for other purposes + +STBIDEF char * stbi_zlib_decode_malloc_guesssize(const char * buffer, int len, int initial_size, int * outlen); +STBIDEF char * stbi_zlib_decode_malloc_guesssize_headerflag(const char * buffer, int len, int initial_size, int * outlen, + int parse_header); +STBIDEF char * stbi_zlib_decode_malloc(const char * buffer, int len, int * outlen); +STBIDEF int stbi_zlib_decode_buffer(char * obuffer, int olen, const char * ibuffer, int ilen); + +STBIDEF char * stbi_zlib_decode_noheader_malloc(const char * buffer, int len, int * outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const char * ibuffer, int ilen); + +#ifdef __cplusplus +} +#endif + +// +// +//// end header file ///////////////////////////////////////////////////// +#endif // 
STBI_INCLUDE_STB_IMAGE_H + +#ifdef STB_IMAGE_IMPLEMENTATION + +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || defined(STBI_ONLY_BMP) || defined(STBI_ONLY_TGA) || \ + defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || \ + defined(STBI_ONLY_PNM) || defined(STBI_ONLY_ZLIB) +#ifndef STBI_ONLY_JPEG +#define STBI_NO_JPEG +#endif +#ifndef STBI_ONLY_PNG +#define STBI_NO_PNG +#endif +#ifndef STBI_ONLY_BMP +#define STBI_NO_BMP +#endif +#ifndef STBI_ONLY_PSD +#define STBI_NO_PSD +#endif +#ifndef STBI_ONLY_TGA +#define STBI_NO_TGA +#endif +#ifndef STBI_ONLY_GIF +#define STBI_NO_GIF +#endif +#ifndef STBI_ONLY_HDR +#define STBI_NO_HDR +#endif +#ifndef STBI_ONLY_PIC +#define STBI_NO_PIC +#endif +#ifndef STBI_ONLY_PNM +#define STBI_NO_PNM +#endif +#endif + +#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB) +#define STBI_NO_ZLIB +#endif + +#include +#include +#include // ptrdiff_t on osx +#include +#include + +#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) +#include // ldexp, pow +#endif + +#ifndef STBI_NO_STDIO +#include +#endif + +#ifndef STBI_ASSERT +#include +#define STBI_ASSERT(x) assert(x) +#endif + +#ifdef __cplusplus +#define STBI_EXTERN extern "C" +#else +#define STBI_EXTERN extern +#endif + +#ifndef _MSC_VER +#ifdef __cplusplus +#define stbi_inline inline +#else +#define stbi_inline +#endif +#else +#define stbi_inline __forceinline +#endif + +#ifndef STBI_NO_THREAD_LOCALS +#if defined(__cplusplus) && __cplusplus >= 201103L +#define STBI_THREAD_LOCAL thread_local +#elif defined(__GNUC__) && __GNUC__ < 5 +#define STBI_THREAD_LOCAL __thread +#elif defined(_MSC_VER) +#define STBI_THREAD_LOCAL __declspec(thread) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__) +#define STBI_THREAD_LOCAL _Thread_local +#endif + +#ifndef STBI_THREAD_LOCAL +#if defined(__GNUC__) +#define STBI_THREAD_LOCAL __thread +#endif +#endif +#endif + +#if defined(_MSC_VER) || defined(__SYMBIAN32__) +typedef unsigned short stbi__uint16; +typedef signed short stbi__int16; +typedef unsigned int stbi__uint32; +typedef signed int stbi__int32; +#else +#include +typedef uint16_t stbi__uint16; +typedef int16_t stbi__int16; +typedef uint32_t stbi__uint32; +typedef int32_t stbi__int32; +#endif + +// should produce compiler error if size is wrong +typedef unsigned char validate_uint32[sizeof(stbi__uint32) == 4 ? 1 : -1]; + +#ifdef _MSC_VER +#define STBI_NOTUSED(v) (void)(v) +#else +#define STBI_NOTUSED(v) (void)sizeof(v) +#endif + +#ifdef _MSC_VER +#define STBI_HAS_LROTL +#endif + +#ifdef STBI_HAS_LROTL +#define stbi_lrot(x, y) _lrotl(x, y) +#else +#define stbi_lrot(x, y) (((x) << (y)) | ((x) >> (-(y)&31))) +#endif + +#if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED)) +// ok +#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED) +// ok +#else +#error "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)." 
+#endif + +#ifndef STBI_MALLOC +#define STBI_MALLOC(sz) malloc(sz) +#define STBI_REALLOC(p, newsz) realloc(p, newsz) +#define STBI_FREE(p) free(p) +#endif + +#ifndef STBI_REALLOC_SIZED +#define STBI_REALLOC_SIZED(p, oldsz, newsz) STBI_REALLOC(p, newsz) +#endif + +// x86/x64 detection +#if defined(__x86_64__) || defined(_M_X64) +#define STBI__X64_TARGET +#elif defined(__i386) || defined(_M_IX86) +#define STBI__X86_TARGET +#endif + +#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && !defined(STBI_NO_SIMD) +// gcc doesn't support sse2 intrinsics unless you compile with -msse2, +// which in turn means it gets to use SSE2 everywhere. This is unfortunate, +// but previous attempts to provide the SSE2 functions with runtime +// detection caused numerous issues. The way architecture extensions are +// exposed in GCC/Clang is, sadly, not really suited for one-file libs. +// New behavior: if compiled with -msse2, we use SSE2 without any +// detection; if not, we don't use it at all. +#define STBI_NO_SIMD +#endif + +#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD) +// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid STBI__X64_TARGET +// +// 32-bit MinGW wants ESP to be 16-byte aligned, but this is not in the +// Windows ABI and VC++ as well as Windows DLLs don't maintain that invariant. +// As a result, enabling SSE2 on 32-bit MinGW is dangerous when not +// simultaneously enabling "-mstackrealign". +// +// See https://github.com/nothings/stb/issues/81 for more information. +// +// So default to no SSE2 on 32-bit MinGW. If you've read this far and added +// -mstackrealign to your build settings, feel free to #define STBI_MINGW_ENABLE_SSE2. +#define STBI_NO_SIMD +#endif + +#if !defined(STBI_NO_SIMD) && (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET)) +#define STBI_SSE2 +#include + +#ifdef _MSC_VER + +#if _MSC_VER >= 1400 // not VC6 +#include // __cpuid +static int stbi__cpuid3(void) { + int info[4]; + __cpuid(info, 1); + return info[3]; +} +#else +static int stbi__cpuid3(void) { + int res; + __asm { + mov eax,1 + cpuid + mov res,edx + } + return res; +} +#endif + +#define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name + +#if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) +static int stbi__sse2_available(void) { + int info3 = stbi__cpuid3(); + return ((info3 >> 26) & 1) != 0; +} +#endif + +#else // assume GCC-style if not VC++ +#define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) + +#if !defined(STBI_NO_JPEG) && defined(STBI_SSE2) +static int stbi__sse2_available(void) { + // If we're even attempting to compile this on GCC/Clang, that means + // -msse2 is on, which means the compiler is allowed to use SSE2 + // instructions at will, and so are we. 
+ return 1; +} +#endif + +#endif +#endif + +// ARM NEON +#if defined(STBI_NO_SIMD) && defined(STBI_NEON) +#undef STBI_NEON +#endif + +#ifdef STBI_NEON +#include +#ifdef _MSC_VER +#define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name +#else +#define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) +#endif +#endif + +#ifndef STBI_SIMD_ALIGN +#define STBI_SIMD_ALIGN(type, name) type name +#endif + +#ifndef STBI_MAX_DIMENSIONS +#define STBI_MAX_DIMENSIONS (1 << 24) +#endif + +/////////////////////////////////////////////// +// +// stbi__context struct and start_xxx functions + +// stbi__context structure is our basic context used by all images, so it +// contains all the IO context, plus some basic image information +typedef struct { + stbi__uint32 img_x, img_y; + int img_n, img_out_n; + + stbi_io_callbacks io; + void * io_user_data; + + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; + + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; +} stbi__context; + +static void stbi__refill_buffer(stbi__context * s); + +// initialize a memory-decode context +static void stbi__start_mem(stbi__context * s, stbi_uc const * buffer, int len) { + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *)buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *)buffer + len; +} + +// initialize a callback-based context +static void stbi__start_callbacks(stbi__context * s, stbi_io_callbacks * c, void * user) { + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; +} + +#ifndef STBI_NO_STDIO + +static int stbi__stdio_read(void * user, char * data, int size) { return (int)fread(data, 1, size, (FILE *)user); } + +static void stbi__stdio_skip(void * user, int n) { + int ch; + fseek((FILE *)user, n, SEEK_CUR); + ch = fgetc((FILE *)user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *)user); /* push byte back onto stream if valid. 
*/ + } +} + +static int stbi__stdio_eof(void * user) { return feof((FILE *)user) || ferror((FILE *)user); } + +static stbi_io_callbacks stbi__stdio_callbacks = { + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, +}; + +static void stbi__start_file(stbi__context * s, FILE * f) { stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *)f); } + +// static void stop_file(stbi__context *s) { } + +#endif // !STBI_NO_STDIO + +static void stbi__rewind(stbi__context * s) { + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; +} + +enum { STBI_ORDER_RGB, STBI_ORDER_BGR }; + +typedef struct { + int bits_per_channel; + int num_channels; + int channel_order; +} stbi__result_info; + +#ifndef STBI_NO_JPEG +static int stbi__jpeg_test(stbi__context * s); +static void * stbi__jpeg_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); +static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp); +#endif + +#ifndef STBI_NO_PNG +static int stbi__png_test(stbi__context * s); +static void * stbi__png_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); +static int stbi__png_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__png_is16(stbi__context * s); +#endif + +#ifndef STBI_NO_BMP +static int stbi__bmp_test(stbi__context * s); +static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); +static int stbi__bmp_info(stbi__context * s, int * x, int * y, int * comp); +#endif + +#ifndef STBI_NO_TGA +static int stbi__tga_test(stbi__context * s); +static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); +static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp); +#endif + +#ifndef STBI_NO_PSD +static int stbi__psd_test(stbi__context * s); +static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc); +static int stbi__psd_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__psd_is16(stbi__context * s); +#endif + +#ifndef STBI_NO_HDR +static int stbi__hdr_test(stbi__context * s); +static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); +static int stbi__hdr_info(stbi__context * s, int * x, int * y, int * comp); +#endif + +#ifndef STBI_NO_PIC +static int stbi__pic_test(stbi__context * s); +static void * stbi__pic_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); +static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp); +#endif + +#ifndef STBI_NO_GIF +static int stbi__gif_test(stbi__context * s); +static void * stbi__gif_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); +static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int * y, int * z, int * comp, int req_comp); +static int stbi__gif_info(stbi__context * s, int * x, int * y, int * comp); +#endif + +#ifndef STBI_NO_PNM +static int stbi__pnm_test(stbi__context * s); +static void * stbi__pnm_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri); 
+static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp); +static int stbi__pnm_is16(stbi__context * s); +#endif + +static +#ifdef STBI_THREAD_LOCAL + STBI_THREAD_LOCAL +#endif + const char * stbi__g_failure_reason; + +STBIDEF const char * stbi_failure_reason(void) { return stbi__g_failure_reason; } + +#ifndef STBI_NO_FAILURE_STRINGS +static int stbi__err(const char * str) { + stbi__g_failure_reason = str; + return 0; +} +#endif + +static void * stbi__malloc(size_t size) { return STBI_MALLOC(size); } + +// stb_image uses ints pervasively, including for offset calculations. +// therefore the largest decoded image size we can support with the +// current code, even on 64-bit targets, is INT_MAX. this is not a +// significant limitation for the intended use case. +// +// we do, however, need to make sure our size calculations don't +// overflow. hence a few helper functions for size calculations that +// multiply integers together, making sure that they're non-negative +// and no overflow occurs. + +// return 1 if the sum is valid, 0 on overflow. +// negative terms are considered invalid. +static int stbi__addsizes_valid(int a, int b) { + if (b < 0) + return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; +} + +// returns 1 if the product is valid, 0 on overflow. +// negative factors are considered invalid. +static int stbi__mul2sizes_valid(int a, int b) { + if (a < 0 || b < 0) + return 0; + if (b == 0) + return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX / b; +} + +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +// returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow +static int stbi__mad2sizes_valid(int a, int b, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a * b, add); +} +#endif + +// returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow +static int stbi__mad3sizes_valid(int a, int b, int c, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && stbi__addsizes_valid(a * b * c, add); +} + +// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow +#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) { + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a * b, c) && stbi__mul2sizes_valid(a * b * c, d) && + stbi__addsizes_valid(a * b * c * d, add); +} +#endif + +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +// mallocs with size overflow checking +static void * stbi__malloc_mad2(int a, int b, int add) { + if (!stbi__mad2sizes_valid(a, b, add)) + return NULL; + return stbi__malloc(a * b + add); +} +#endif + +static void * stbi__malloc_mad3(int a, int b, int c, int add) { + if (!stbi__mad3sizes_valid(a, b, c, add)) + return NULL; + return stbi__malloc(a * b * c + add); +} + +#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) +static void * stbi__malloc_mad4(int a, int b, int c, int d, int add) { + if (!stbi__mad4sizes_valid(a, b, c, d, add)) + return NULL; + return stbi__malloc(a * b * c * d + add); +} +#endif + +// returns 1 if the sum of two signed ints is valid (between -2^31 and 2^31-1 
inclusive), 0 on overflow. +static int stbi__addints_valid(int a, int b) { + if ((a >= 0) != (b >= 0)) + return 1; // a and b have different signs, so no overflow + if (a < 0 && b < 0) + return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0. + return a <= INT_MAX - b; +} + +// returns 1 if the product of two signed shorts is valid, 0 on overflow. +static int stbi__mul2shorts_valid(short a, short b) { + if (b == 0 || b == -1) + return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow + if ((a >= 0) == (b >= 0)) + return a <= SHRT_MAX / b; // product is positive, so similar to mul2sizes_valid + if (b < 0) + return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN + return a >= SHRT_MIN / b; +} + +// stbi__err - error +// stbi__errpf - error returning pointer to float +// stbi__errpuc - error returning pointer to unsigned char + +#ifdef STBI_NO_FAILURE_STRINGS +#define stbi__err(x, y) 0 +#elif defined(STBI_FAILURE_USERMSG) +#define stbi__err(x, y) stbi__err(y) +#else +#define stbi__err(x, y) stbi__err(x) +#endif + +#define stbi__errpf(x, y) ((float *)(size_t)(stbi__err(x, y) ? NULL : NULL)) +#define stbi__errpuc(x, y) ((unsigned char *)(size_t)(stbi__err(x, y) ? NULL : NULL)) + +STBIDEF void stbi_image_free(void * retval_from_stbi_load) { STBI_FREE(retval_from_stbi_load); } + +#ifndef STBI_NO_LINEAR +static float * stbi__ldr_to_hdr(stbi_uc * data, int x, int y, int comp); +#endif + +#ifndef STBI_NO_HDR +static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp); +#endif + +static int stbi__vertically_flip_on_load_global = 0; + +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +} + +#ifndef STBI_THREAD_LOCAL +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#else +static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; + +STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) { + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; +} + +#define stbi__vertically_flip_on_load \ + (stbi__vertically_flip_on_load_set ? 
stbi__vertically_flip_on_load_local : stbi__vertically_flip_on_load_global) +#endif // STBI_THREAD_LOCAL + +static void * stbi__load_main(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc) { + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed + ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order + ri->num_channels = 0; + +// test the formats with a very explicit header first (at least a FOURCC +// or distinctive magic number first) +#ifndef STBI_NO_PNG + if (stbi__png_test(s)) + return stbi__png_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) + return stbi__bmp_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_GIF + if (stbi__gif_test(s)) + return stbi__gif_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PSD + if (stbi__psd_test(s)) + return stbi__psd_load(s, x, y, comp, req_comp, ri, bpc); +#else + STBI_NOTUSED(bpc); +#endif +#ifndef STBI_NO_PIC + if (stbi__pic_test(s)) + return stbi__pic_load(s, x, y, comp, req_comp, ri); +#endif + +// then the formats that can end up attempting to load with just 1 or 2 +// bytes matching expectations; these are prone to false positives, so +// try them later +#ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) + return stbi__jpeg_load(s, x, y, comp, req_comp, ri); +#endif +#ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) + return stbi__pnm_load(s, x, y, comp, req_comp, ri); +#endif + +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float * hdr = stbi__hdr_load(s, x, y, comp, req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } +#endif + +#ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s, x, y, comp, req_comp, ri); +#endif + + return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); +} + +static stbi_uc * stbi__convert_16_to_8(stbi__uint16 * orig, int w, int h, int channels) { + int i; + int img_len = w * h * channels; + stbi_uc * reduced; + + reduced = (stbi_uc *)stbi__malloc(img_len); + if (reduced == NULL) + return stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling + + STBI_FREE(orig); + return reduced; +} + +static stbi__uint16 * stbi__convert_8_to_16(stbi_uc * orig, int w, int h, int channels) { + int i; + int img_len = w * h * channels; + stbi__uint16 * enlarged; + + enlarged = (stbi__uint16 *)stbi__malloc(img_len * 2); + if (enlarged == NULL) + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff + + STBI_FREE(orig); + return enlarged; +} + +static void stbi__vertical_flip(void * image, int w, int h, int bytes_per_pixel) { + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc * bytes = (stbi_uc *)image; + + for (row = 0; row < (h >> 1); row++) { + stbi_uc * row0 = bytes + row * bytes_per_row; + stbi_uc * row1 = bytes + (h - row - 1) * bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = (bytes_left < sizeof(temp)) ? 
bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } +} + +#ifndef STBI_NO_GIF +static void stbi__vertical_flip_slices(void * image, int w, int h, int z, int bytes_per_pixel) { + int slice; + int slice_size = w * h * bytes_per_pixel; + + stbi_uc * bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } +} +#endif + +static unsigned char * stbi__load_and_postprocess_8bit(stbi__context * s, int * x, int * y, int * comp, int req_comp) { + stbi__result_info ri; + void * result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); + + if (result == NULL) + return NULL; + + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *)result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } + + // @TODO: move stbi__convert_format to here + + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } + + return (unsigned char *)result; +} + +static stbi__uint16 * stbi__load_and_postprocess_16bit(stbi__context * s, int * x, int * y, int * comp, int req_comp) { + stbi__result_info ri; + void * result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); + + if (result == NULL) + return NULL; + + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *)result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } + + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } + + return (stbi__uint16 *)result; +} + +#if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) +static void stbi__float_postprocess(float * result, int * x, int * y, int * comp, int req_comp) { + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? 
req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } +} +#endif + +#ifndef STBI_NO_STDIO + +#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char * str, + int cbmb, wchar_t * widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, + const wchar_t * widestr, int cchwide, char * str, int cbmb, + const char * defchar, int * used_default); +#endif + +#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) +STBIDEF int stbi_convert_wchar_to_utf8(char * buffer, size_t bufferlen, const wchar_t * input) { + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int)bufferlen, NULL, NULL); +} +#endif + +static FILE * stbi__fopen(char const * filename, char const * mode) { + FILE * f; +#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename) / sizeof(*wFilename))) + return 0; + + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode) / sizeof(*wMode))) + return 0; + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; +#else + f = _wfopen(wFilename, wMode); +#endif + +#elif defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != fopen_s(&f, filename, mode)) + f = 0; +#else + f = fopen(filename, mode); +#endif + return f; +} + +STBIDEF stbi_uc * stbi_load(char const * filename, int * x, int * y, int * comp, int req_comp) { + FILE * f = stbi__fopen(filename, "rb"); + unsigned char * result; + if (!f) + return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF stbi_uc * stbi_load_from_file(FILE * f, int * x, int * y, int * comp, int req_comp) { + unsigned char * result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi__uint16 * stbi_load_from_file_16(FILE * f, int * x, int * y, int * comp, int req_comp) { + stbi__uint16 * result; + stbi__context s; + stbi__start_file(&s, f); + result = stbi__load_and_postprocess_16bit(&s, x, y, comp, req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, -(int)(s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi_us * stbi_load_16(char const * filename, int * x, int * y, int * comp, int req_comp) { + FILE * f = stbi__fopen(filename, "rb"); + stbi__uint16 * result; + if (!f) + return (stbi_us *)stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +#endif //! 
STBI_NO_STDIO + +STBIDEF stbi_us * stbi_load_16_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * channels_in_file, + int desired_channels) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, desired_channels); +} + +STBIDEF stbi_us * stbi_load_16_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, + int * channels_in_file, int desired_channels) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_16bit(&s, x, y, channels_in_file, desired_channels); +} + +STBIDEF stbi_uc * stbi_load_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +} + +STBIDEF stbi_uc * stbi_load_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__load_and_postprocess_8bit(&s, x, y, comp, req_comp); +} + +#ifndef STBI_NO_GIF +STBIDEF stbi_uc * stbi_load_gif_from_memory(stbi_uc const * buffer, int len, int ** delays, int * x, int * y, int * z, + int * comp, int req_comp) { + unsigned char * result; + stbi__context s; + stbi__start_mem(&s, buffer, len); + + result = (unsigned char *)stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices(result, *x, *y, *z, *comp); + } + + return result; +} +#endif + +#ifndef STBI_NO_LINEAR +static float * stbi__loadf_main(stbi__context * s, int * x, int * y, int * comp, int req_comp) { + unsigned char * data; +#ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float * hdr_data = stbi__hdr_load(s, x, y, comp, req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data, x, y, comp, req_comp); + return hdr_data; + } +#endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? 
req_comp : *comp); + return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); +} + +STBIDEF float * stbi_loadf_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp, int req_comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__loadf_main(&s, x, y, comp, req_comp); +} + +STBIDEF float * stbi_loadf_from_callbacks(stbi_io_callbacks const * clbk, void * user, int * x, int * y, int * comp, + int req_comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__loadf_main(&s, x, y, comp, req_comp); +} + +#ifndef STBI_NO_STDIO +STBIDEF float * stbi_loadf(char const * filename, int * x, int * y, int * comp, int req_comp) { + float * result; + FILE * f = stbi__fopen(filename, "rb"); + if (!f) + return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f, x, y, comp, req_comp); + fclose(f); + return result; +} + +STBIDEF float * stbi_loadf_from_file(FILE * f, int * x, int * y, int * comp, int req_comp) { + stbi__context s; + stbi__start_file(&s, f); + return stbi__loadf_main(&s, x, y, comp, req_comp); +} +#endif // !STBI_NO_STDIO + +#endif // !STBI_NO_LINEAR + +// these is-hdr-or-not is defined independent of whether STBI_NO_LINEAR is +// defined, for API simplicity; if STBI_NO_LINEAR is defined, it always +// reports false! + +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const * buffer, int len) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; +#endif +} + +#ifndef STBI_NO_STDIO +STBIDEF int stbi_is_hdr(char const * filename) { + FILE * f = stbi__fopen(filename, "rb"); + int result = 0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; +} + +STBIDEF int stbi_is_hdr_from_file(FILE * f) { +#ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s, f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; +#else + STBI_NOTUSED(f); + return 0; +#endif +} +#endif // !STBI_NO_STDIO + +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const * clbk, void * user) { +#ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return stbi__hdr_test(&s); +#else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; +#endif +} + +#ifndef STBI_NO_LINEAR +static float stbi__l2h_gamma = 2.2f, stbi__l2h_scale = 1.0f; + +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +#endif + +static float stbi__h2l_gamma_i = 1.0f / 2.2f, stbi__h2l_scale_i = 1.0f; + +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1 / gamma; } +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1 / scale; } + +////////////////////////////////////////////////////////////////////////////// +// +// Common code used by all image loaders +// + +enum { STBI__SCAN_load = 0, STBI__SCAN_type, STBI__SCAN_header }; + +static void stbi__refill_buffer(stbi__context * s) { + int n = (s->io.read)(s->io_user_data, (char *)s->buffer_start, s->buflen); + s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 
0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + 1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } +} + +stbi_inline static stbi_uc stbi__get8(stbi__context * s) { + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; +} + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +// nothing +#else +stbi_inline static int stbi__at_eof(stbi__context * s) { + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) + return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) + return 1; + } + + return s->img_buffer >= s->img_buffer_end; +} +#endif + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ + defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +// nothing +#else +static void stbi__skip(stbi__context * s, int n) { + if (n == 0) + return; // already there! + if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { + s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); + return; + } + } + s->img_buffer += n; +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) +// nothing +#else +static int stbi__getn(stbi__context * s, stbi_uc * buffer, int n) { + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; + + memcpy(buffer, s->img_buffer, blen); + + count = (s->io.read)(s->io_user_data, (char *)buffer + blen, n - blen); + res = (count == (n - blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } + + if (s->img_buffer + n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; +} +#endif + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +// nothing +#else +static int stbi__get16be(stbi__context * s) { + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +// nothing +#else +static stbi__uint32 stbi__get32be(stbi__context * s) { + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); +} +#endif + +#if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) +// nothing +#else +static int stbi__get16le(stbi__context * s) { + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); +} +#endif + +#ifndef STBI_NO_BMP +static stbi__uint32 stbi__get32le(stbi__context * s) { + stbi__uint32 z = stbi__get16le(s); + z += (stbi__uint32)stbi__get16le(s) << 16; + return z; +} +#endif + +#define STBI__BYTECAST(x) ((stbi_uc)((x)&255)) // truncate int to byte without warnings + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && \ + defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +// nothing +#else +////////////////////////////////////////////////////////////////////////////// +// +// generic converter from built-in img_n to req_comp +// individual types do this automatically as much 
as possible (e.g. jpeg +// does all cases internally since it needs to colorspace convert anyway, +// and it never has alpha, so very few cases ). png can automatically +// interleave an alpha=255 channel, but falls back to this for other cases +// +// assume data buffer is malloced, so malloc a new one and free that one +// only failure mode is malloc failing + +static stbi_uc stbi__compute_y(int r, int g, int b) { return (stbi_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); } +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && \ + defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +// nothing +#else +static unsigned char * stbi__convert_format(unsigned char * data, int img_n, int req_comp, unsigned int x, unsigned int y) { + int i, j; + unsigned char * good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (unsigned char *)stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) { + unsigned char * src = data + j * x * img_n; + unsigned char * dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a)*8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 255; + } + break; + STBI__CASE(1, 3) { dest[0] = dest[1] = dest[2] = src[0]; } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 255; + } + break; + STBI__CASE(2, 1) { dest[0] = src[0]; } + break; + STBI__CASE(2, 3) { dest[0] = dest[1] = dest[2] = src[0]; } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 255; + } + break; + STBI__CASE(3, 1) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = 255; + } + break; + STBI__CASE(4, 1) { dest[0] = stbi__compute_y(src[0], src[1], src[2]); } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return stbi__errpuc("unsupported", "Unsupported format conversion"); + } +#undef STBI__CASE + } + + STBI_FREE(data); + return good; +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) +// nothing +#else +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) { return (stbi__uint16)(((r * 77) + (g * 150) + (29 * b)) >> 8); } +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) +// nothing +#else +static stbi__uint16 * stbi__convert_format16(stbi__uint16 * data, int img_n, int req_comp, unsigned int x, unsigned int y) { + int i, j; + stbi__uint16 * good; + + if (req_comp == img_n) + return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (stbi__uint16 *)stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *)stbi__errpuc("outofmem", "Out of memory"); + } + + for (j = 0; j < (int)y; ++j) 
{ + stbi__uint16 * src = data + j * x * img_n; + stbi__uint16 * dest = good + j * x * req_comp; + +#define STBI__COMBO(a, b) ((a)*8 + (b)) +#define STBI__CASE(a, b) \ + case STBI__COMBO(a, b): \ + for (i = x - 1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1, 2) { + dest[0] = src[0]; + dest[1] = 0xffff; + } + break; + STBI__CASE(1, 3) { dest[0] = dest[1] = dest[2] = src[0]; } + break; + STBI__CASE(1, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = 0xffff; + } + break; + STBI__CASE(2, 1) { dest[0] = src[0]; } + break; + STBI__CASE(2, 3) { dest[0] = dest[1] = dest[2] = src[0]; } + break; + STBI__CASE(2, 4) { + dest[0] = dest[1] = dest[2] = src[0]; + dest[3] = src[1]; + } + break; + STBI__CASE(3, 4) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + dest[3] = 0xffff; + } + break; + STBI__CASE(3, 1) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } + break; + STBI__CASE(3, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = 0xffff; + } + break; + STBI__CASE(4, 1) { dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); } + break; + STBI__CASE(4, 2) { + dest[0] = stbi__compute_y_16(src[0], src[1], src[2]); + dest[1] = src[3]; + } + break; + STBI__CASE(4, 3) { + dest[0] = src[0]; + dest[1] = src[1]; + dest[2] = src[2]; + } + break; + default: + STBI_ASSERT(0); + STBI_FREE(data); + STBI_FREE(good); + return (stbi__uint16 *)stbi__errpuc("unsupported", "Unsupported format conversion"); + } +#undef STBI__CASE + } + + STBI_FREE(data); + return good; +} +#endif + +#ifndef STBI_NO_LINEAR +static float * stbi__ldr_to_hdr(stbi_uc * data, int x, int y, int comp) { + int i, k, n; + float * output; + if (!data) + return NULL; + output = (float *)stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpf("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + output[i * comp + k] = (float)(pow(data[i * comp + k] / 255.0f, stbi__l2h_gamma) * stbi__l2h_scale); + } + } + if (n < comp) { + for (i = 0; i < x * y; ++i) { + output[i * comp + n] = data[i * comp + n] / 255.0f; + } + } + STBI_FREE(data); + return output; +} +#endif + +#ifndef STBI_NO_HDR +#define stbi__float2int(x) ((int)(x)) +static stbi_uc * stbi__hdr_to_ldr(float * data, int x, int y, int comp) { + int i, k, n; + stbi_uc * output; + if (!data) + return NULL; + output = (stbi_uc *)stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + // compute number of non-alpha components + if (comp & 1) + n = comp; + else + n = comp - 1; + for (i = 0; i < x * y; ++i) { + for (k = 0; k < n; ++k) { + float z = (float)pow(data[i * comp + k] * stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + if (k < comp) { + float z = data[i * comp + k] * 255 + 0.5f; + if (z < 0) + z = 0; + if (z > 255) + z = 255; + output[i * comp + k] = (stbi_uc)stbi__float2int(z); + } + } + STBI_FREE(data); + return output; +} +#endif + +////////////////////////////////////////////////////////////////////////////// +// +// "baseline" JPEG/JFIF decoder +// +// simple 
implementation +// - doesn't support delayed output of y-dimension +// - simple interface (only one output format: 8-bit interleaved RGB) +// - doesn't try to recover corrupt jpegs +// - doesn't allow partial loading, loading multiple at once +// - still fast on x86 (copying globals into locals doesn't help x86) +// - allocates lots of intermediate memory (full size of all components) +// - non-interleaved case requires this anyway +// - allows good upsampling (see next) +// high-quality +// - upsampled channels are bilinearly interpolated, even across blocks +// - quality integer IDCT derived from IJG's 'slow' +// performance +// - fast huffman; reasonable integer IDCT +// - some SIMD kernels for common paths on targets with SSE2/NEON +// - uses a lot of intermediate memory, could cache poorly + +#ifndef STBI_NO_JPEG + +// huffman decoding acceleration +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + +typedef struct { + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' +} stbi__huffman; + +typedef struct { + stbi__context * s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; + + // sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + + // definition of jpeg image component + struct { + int id; + int h, v; + int tq; + int hd, ha; + int dc_pred; + + int x, y, w2, h2; + stbi_uc * data; + void *raw_data, *raw_coeff; + stbi_uc * linebuf; + short * coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + + // kernels + void (*idct_block_kernel)(stbi_uc * out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc * out, const stbi_uc * y, const stbi_uc * pcb, const stbi_uc * pcr, int count, + int step); + stbi_uc * (*resample_row_hv_2_kernel)(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs); +} stbi__jpeg; + +static int stbi__build_huffman(stbi__huffman * h, int * count) { + int i, j, k = 0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i = 0; i < 16; ++i) { + for (j = 0; j < count[i]; ++j) { + h->size[k++] = (stbi_uc)(i + 1); + if (k >= 257) + return stbi__err("bad size list", "Corrupt JPEG"); + } + } + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for (j = 1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16)(code++); + if (code - 1 >= (1u << j)) + return stbi__err("bad code lengths", "Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16 - j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec 
acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i = 0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS - s); + int m = 1 << (FAST_BITS - s); + for (j = 0; j < m; ++j) { + h->fast[c + j] = (stbi_uc)i; + } + } + } + return 1; +} + +// build a table that decodes both magnitude and value of small ACs in +// one go. +static void stbi__build_fast_ac(stbi__int16 * fast_ac, stbi__huffman * h) { + int i; + for (i = 0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) + k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16)((k * 256) + (run * 16) + (len + magbits)); + } + } + } +} + +static void stbi__grow_buffer_unsafe(stbi__jpeg * j) { + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) + c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char)c; + j->nomore = 1; + return; + } + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); +} + +// (1 << n) - 1 +static const stbi__uint32 stbi__bmask[17] = {0, 1, 3, 7, 15, 31, 63, 127, 255, + 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535}; + +// decode a jpeg huffman value from the bitstream +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg * j, stbi__huffman * h) { + unsigned int temp; + int c, k; + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) + return -1; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k = FAST_BITS + 1;; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + if (c < 0 || c >= 256) // symbol id out of bounds! 
+ return -1;
+ STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]);
+
+ // convert the id to a symbol
+ j->code_bits -= k;
+ j->code_buffer <<= k;
+ return h->values[c];
+}
+
+// bias[n] = (-1<<n) + 1
+static const int stbi__jbias[16] = {0, -1, -3, -7, -15, -31, -63, -127, -255, -511, -1023, -2047, -4095, -8191, -16383, -32767};
+
+// combined JPEG 'receive' and JPEG 'extend', since baseline
+// always extends everything it receives.
+stbi_inline static int stbi__extend_receive(stbi__jpeg * j, int n) {
+ unsigned int k;
+ int sgn;
+ if (j->code_bits < n)
+ stbi__grow_buffer_unsafe(j);
+ if (j->code_bits < n)
+ return 0; // ran out of bits from stream, return 0s instead of continuing
+
+ sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative)
+ k = stbi_lrot(j->code_buffer, n);
+ j->code_buffer = k & ~stbi__bmask[n];
+ k &= stbi__bmask[n];
+ j->code_bits -= n;
+ return k + (stbi__jbias[n] & (sgn - 1));
+}
+
+// get some unsigned bits
+stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg * j, int n) {
+ unsigned int k;
+ if (j->code_bits < n)
+ stbi__grow_buffer_unsafe(j);
+ if (j->code_bits < n)
+ return 0; // ran out of bits from stream, return 0s instead of continuing
+ k = stbi_lrot(j->code_buffer, n);
+ j->code_buffer = k & ~stbi__bmask[n];
+ k &= stbi__bmask[n];
+ j->code_bits -= n;
+ return k;
+}
+
+stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg * j) {
+ unsigned int k;
+ if (j->code_bits < 1)
+ stbi__grow_buffer_unsafe(j);
+ if (j->code_bits < 1)
+ return 0; // ran out of bits from stream, return 0s instead of continuing
+ k = j->code_buffer;
+ j->code_buffer <<= 1;
+ --j->code_bits;
+ return k & 0x80000000;
+}
+
+// given a value that's at position X in the zigzag stream,
+// where does it appear in the 8x8 matrix coded as row-major?
+static const stbi_uc stbi__jpeg_dezigzag[64 + 15] = {
+ 0, 1, 8, 16, 9, 2, 3, 10, 17, 24, 32, 25, 18, 11, 4, 5, 12, 19, 26, 33, 40, 48, 41, 34, 27, 20, 13, 6, 7, 14, 21, 28, 35,
+ 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51, 58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63,
+ // let corrupt input sample past end
+ 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63};
+
+// decode one 64-entry block--
+static int stbi__jpeg_decode_block(stbi__jpeg * j, short data[64], stbi__huffman * hdc, stbi__huffman * hac, stbi__int16 * fac,
+ int b, stbi__uint16 * dequant) {
+ int diff, dc, k;
+ int t;
+
+ if (j->code_bits < 16)
+ stbi__grow_buffer_unsafe(j);
+ t = stbi__jpeg_huff_decode(j, hdc);
+ if (t < 0 || t > 15)
+ return stbi__err("bad huffman code", "Corrupt JPEG");
+
+ // 0 all the ac values now so we can do it 32-bits at a time
+ memset(data, 0, 64 * sizeof(data[0]));
+
+ diff = t ? 
stbi__extend_receive(j, t) : 0; + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) + return stbi__err("bad delta", "Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, dequant[0])) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short)(dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + if (s > j->code_bits) + return stbi__err("bad huffman code", "Combined length longer than code bits available"); + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) + break; // end block + k += 16; + } else { + k += r; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) * dequant[zig]); + } + } + } while (k < 64); + return 1; +} + +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg * j, short data[64], stbi__huffman * hdc, int b) { + int diff, dc; + int t; + if (j->spec_end != 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0 || t > 15) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? 
stbi__extend_receive(j, t) : 0; + + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) + return stbi__err("bad delta", "Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, 1 << j->succ_low)) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short)(dc * (1 << j->succ_low)); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short)(1 << j->succ_low); + } + return 1; +} + +// @OPTIMIZE: store non-zigzagged during the decode passes, +// and only de-zigzag when dequantizing +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg * j, short data[64], stbi__huffman * hac, stbi__int16 * fac) { + int k; + if (j->spec_start == 0) + return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } + + k = j->spec_start; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) + stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + if (s > j->code_bits) + return stbi__err("bad huffman code", "Combined length longer than code bits available"); + j->code_buffer <<= s; + j->code_bits -= s; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) * (1 << shift)); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; + } else { + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short)(stbi__extend_receive(j, s) * (1 << shift)); + } + } + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients + + short bit = (short)(1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short * p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { + k = j->spec_start; + do { + int r, s; + int rs = stbi__jpeg_huff_decode( + j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh + if (rs < 0) + return stbi__err("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } else { + if (s != 1) + return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; + } + + // advance by r + while (k <= j->spec_end) { + short * p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short)s; + break; + } + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; +} + +// take a -128..127 value and stbi__clamp it and convert to 0..255 +stbi_inline static stbi_uc stbi__clamp(int x) { + // trick to use a single test to catch both cases + if 
((unsigned int)x > 255) { + if (x < 0) + return 0; + if (x > 255) + return 255; + } + return (stbi_uc)x; +} + +#define stbi__f2f(x) ((int)(((x)*4096 + 0.5))) +#define stbi__fsh(x) ((x)*4096) + +// derived from jidctint -- DCT_ISLOW +#define STBI__IDCT_1D(s0, s1, s2, s3, s4, s5, s6, s7) \ + int t0, t1, t2, t3, p1, p2, p3, p4, p5, x0, x1, x2, x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2 + p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3 * stbi__f2f(-1.847759065f); \ + t3 = p1 + p2 * stbi__f2f(0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2 + p3); \ + t1 = stbi__fsh(p2 - p3); \ + x0 = t0 + t3; \ + x3 = t0 - t3; \ + x1 = t1 + t2; \ + x2 = t1 - t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0 + t2; \ + p4 = t1 + t3; \ + p1 = t0 + t3; \ + p2 = t1 + t2; \ + p5 = (p3 + p4) * stbi__f2f(1.175875602f); \ + t0 = t0 * stbi__f2f(0.298631336f); \ + t1 = t1 * stbi__f2f(2.053119869f); \ + t2 = t2 * stbi__f2f(3.072711026f); \ + t3 = t3 * stbi__f2f(1.501321110f); \ + p1 = p5 + p1 * stbi__f2f(-0.899976223f); \ + p2 = p5 + p2 * stbi__f2f(-2.562915447f); \ + p3 = p3 * stbi__f2f(-1.961570560f); \ + p4 = p4 * stbi__f2f(-0.390180644f); \ + t3 += p1 + p4; \ + t2 += p2 + p3; \ + t1 += p2 + p4; \ + t0 += p1 + p3; + +static void stbi__idct_block(stbi_uc * out, int out_stride, short data[64]) { + int i, val[64], *v = val; + stbi_uc * o; + short * d = data; + + // columns + for (i = 0; i < 8; ++i, ++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[8] == 0 && d[16] == 0 && d[24] == 0 && d[32] == 0 && d[40] == 0 && d[48] == 0 && d[56] == 0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0] * 4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[0], d[8], d[16], d[24], d[32], d[40], d[48], d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; + x1 += 512; + x2 += 512; + x3 += 512; + v[0] = (x0 + t3) >> 10; + v[56] = (x0 - t3) >> 10; + v[8] = (x1 + t2) >> 10; + v[48] = (x1 - t2) >> 10; + v[16] = (x2 + t1) >> 10; + v[40] = (x2 - t1) >> 10; + v[24] = (x3 + t0) >> 10; + v[32] = (x3 - t0) >> 10; + } + } + + for (i = 0, v = val, o = out; i < 8; ++i, v += 8, o += out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128 << 17); + x1 += 65536 + (128 << 17); + x2 += 65536 + (128 << 17); + x3 += 65536 + (128 << 17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0 + t3) >> 17); + o[7] = stbi__clamp((x0 - t3) >> 17); + o[1] = stbi__clamp((x1 + t2) >> 17); + o[6] = stbi__clamp((x1 - t2) >> 17); + o[2] = stbi__clamp((x2 + t1) >> 17); + o[5] = stbi__clamp((x2 - t1) >> 17); + o[3] = stbi__clamp((x3 + t0) >> 17); + o[4] = stbi__clamp((x3 - t0) >> 17); + } +} + +#ifdef STBI_SSE2 +// sse2 integer IDCT. 
not the fastest possible implementation but it +// produces bit-identical results to the generic C version so it's +// fully "transparent". +static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + +// dot product constant: even elems=x, odd elems=y +#define dct_const(x, y) _mm_setr_epi16((x), (y), (x), (y), (x), (y), (x), (y)) + +// out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) +// out(1) = c1[even]*x + c1[odd]*y +#define dct_rot(out0, out1, x, y, c0, c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x), (y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x), (y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + +// out = in << 12 (in 16-bit, out 32-bit) +#define dct_widen(out, in) \ + __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) + +// wide add +#define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) + +// wide sub +#define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + +// butterfly a/b, add bias, then shift by "s" and pack +#define dct_bfly32o(out0, out1, a, b, bias, s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + +// 8-bit interleave step (for transposes) +#define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + +// 16-bit interleave step (for transposes) +#define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + +#define dct_pass(bias, shift) \ + { \ + /* even part */ \ + dct_rot(t2e, t3e, row2, row6, rot0_0, rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o, y2o, row7, row3, rot2_0, rot2_1); \ + dct_rot(y1o, y3o, row5, row1, rot3_0, rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o, y5o, sum17, sum35, rot1_0, rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0, row7, x0, x7, bias, shift); \ + dct_bfly32o(row1, row6, x1, x6, bias, shift); \ + dct_bfly32o(row2, row5, x2, x5, bias, shift); \ + dct_bfly32o(row3, row4, x3, x4, bias, shift); \ + } + + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f(0.765366865f), stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); + __m128i rot1_1 = 
dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f(0.298631336f), stbi__f2f(-1.961570560f)); + __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f(3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f(2.053119869f), stbi__f2f(-0.390180644f)); + __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f(1.501321110f)); + + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128 << 17)); + + // load + row0 = _mm_load_si128((const __m128i *)(data + 0 * 8)); + row1 = _mm_load_si128((const __m128i *)(data + 1 * 8)); + row2 = _mm_load_si128((const __m128i *)(data + 2 * 8)); + row3 = _mm_load_si128((const __m128i *)(data + 3 * 8)); + row4 = _mm_load_si128((const __m128i *)(data + 4 * 8)); + row5 = _mm_load_si128((const __m128i *)(data + 5 * 8)); + row6 = _mm_load_si128((const __m128i *)(data + 6 * 8)); + row7 = _mm_load_si128((const __m128i *)(data + 7 * 8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i *)out, p0); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p0, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p2); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p2, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p1); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p1, 0x4e)); + out += out_stride; + _mm_storel_epi64((__m128i *)out, p3); + out += out_stride; + _mm_storel_epi64((__m128i *)out, _mm_shuffle_epi32(p3, 0x4e)); + } + +#undef dct_const +#undef dct_rot +#undef dct_widen +#undef dct_wadd +#undef dct_wsub +#undef dct_bfly32o +#undef dct_interleave8 +#undef dct_interleave16 +#undef dct_pass +} + +#endif // STBI_SSE2 + +#ifdef STBI_NEON + +// NEON integer IDCT. should produce bit-identical +// results to the generic C version. 
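+//
+// The same fixed-point bookkeeping as stbi__idct_block() applies here, even though
+// the rounding is folded in differently than in the SSE2 kernel: stbi__f2f()
+// constants carry 12 fractional bits, the column pass shifts right by 10 (keeping
+// 2 extra bits of precision), and the row pass removes 2 + 12 + 3 = 17 bits in
+// total (the extra 3 being the sqrt(8) gain of each 1-D pass), with the +128 level
+// shift pre-added as the 1024 DC bias below.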
+static void stbi__idct_simd(stbi_uc * out, int out_stride, short data[64]) { + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; + + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f(0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f(1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f(0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f(2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f(3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f(1.501321110f)); + +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) + +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) + +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) + +// wide add +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) + +// wide sub +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) + +// butterfly a/b, then shift using "shiftop" by "s" and pack +#define dct_bfly32o(out0, out1, a, b, shiftop, s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0, row7, x0, x7, shiftop, shift); \ + dct_bfly32o(row1, row6, x1, x6, shiftop, shift); \ + dct_bfly32o(row2, row5, x2, x5, shiftop, shift); \ + dct_bfly32o(row3, row4, x3, x4, shiftop, shift); \ + } + + // load + row0 = vld1q_s16(data + 0 * 8); + row1 = vld1q_s16(data + 1 * 8); + row2 = vld1q_s16(data + 2 * 8); + row3 = vld1q_s16(data + 3 * 8); + row4 = 
vld1q_s16(data + 4 * 8); + row5 = vld1q_s16(data + 5 * 8); + row6 = vld1q_s16(data + 6 * 8); + row7 = vld1q_s16(data + 7 * 8); + + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + + // column pass + dct_pass(vrshrn_n_s32, 10); + + // 16bit 8x8 transpose + { +// these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. +// whether compilers actually get this is another story, sadly. +#define dct_trn16(x, y) \ + { \ + int16x8x2_t t = vtrnq_s16(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn32(x, y) \ + { \ + int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); \ + x = vreinterpretq_s16_s32(t.val[0]); \ + y = vreinterpretq_s16_s32(t.val[1]); \ + } +#define dct_trn64(x, y) \ + { \ + int16x8_t x0 = x; \ + int16x8_t y0 = y; \ + x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); \ + y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); \ + } + + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); + + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); + + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); + +#undef dct_trn16 +#undef dct_trn32 +#undef dct_trn64 + } + + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); + + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) \ + { \ + uint8x8x2_t t = vtrn_u8(x, y); \ + x = t.val[0]; \ + y = t.val[1]; \ + } +#define dct_trn8_16(x, y) \ + { \ + uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); \ + x = vreinterpret_u8_u16(t.val[0]); \ + y = vreinterpret_u8_u16(t.val[1]); \ + } +#define dct_trn8_32(x, y) \ + { \ + uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); \ + x = vreinterpret_u8_u32(t.val[0]); \ + y = vreinterpret_u8_u32(t.val[1]); \ + } + + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! 
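+        // the three pairwise transposes below work at 1-, 2- and 4-byte
+        // granularity; composed, they give the full 8x8 byte transpose,
+        // mirroring the 16-bit transpose done between the column and row passes.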
+ + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); + + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); + + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); + + // store + vst1_u8(out, p0); + out += out_stride; + vst1_u8(out, p1); + out += out_stride; + vst1_u8(out, p2); + out += out_stride; + vst1_u8(out, p3); + out += out_stride; + vst1_u8(out, p4); + out += out_stride; + vst1_u8(out, p5); + out += out_stride; + vst1_u8(out, p6); + out += out_stride; + vst1_u8(out, p7); + +#undef dct_trn8_8 +#undef dct_trn8_16 +#undef dct_trn8_32 + } + +#undef dct_long_mul +#undef dct_long_mac +#undef dct_widen +#undef dct_wadd +#undef dct_wsub +#undef dct_bfly32o +#undef dct_pass +} + +#endif // STBI_NEON + +#define STBI__MARKER_none 0xff +// if there's a pending marker from the entropy stream, return that +// otherwise, fetch from the stream and get a marker. if there's no +// marker, return 0xff, which is never a valid marker value +static stbi_uc stbi__get_marker(stbi__jpeg * j) { + stbi_uc x; + if (j->marker != STBI__MARKER_none) { + x = j->marker; + j->marker = STBI__MARKER_none; + return x; + } + x = stbi__get8(j->s); + if (x != 0xff) + return STBI__MARKER_none; + while (x == 0xff) + x = stbi__get8(j->s); // consume repeated 0xff fill bytes + return x; +} + +// in each scan, we'll have scan_n components, and the order +// of the components is specified by order[] +#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) + +// after a restart interval, stbi__jpeg_reset the entropy decoder and +// the dc prediction +static void stbi__jpeg_reset(stbi__jpeg * j) { + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; + j->marker = STBI__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels +} + +static int stbi__parse_entropy_coded_data(stbi__jpeg * z) { + stbi__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i, j; + STBI_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, z->huff_ac + ha, z->fast_ac[ha], n, + z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + i * 8, z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } else { // interleaved + int i, j, k, x, y; + STBI_SIMD_ALIGN(short, data[64]); + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... 
process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x) * 8; + int y2 = (j * z->img_comp[n].v + y) * 8; + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, z->huff_ac + ha, + z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) + return 0; + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * y2 + x2, z->img_comp[n].w2, + data); + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } + } else { + if (z->scan_n == 1) { + int i, j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short * data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } else { // interleaved + int i, j, k, x, y; + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... 
process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x); + int y2 = (j * z->img_comp[n].v + y); + short * data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) + stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) + return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } + } +} + +static void stbi__jpeg_dequantize(short * data, stbi__uint16 * dequant) { + int i; + for (i = 0; i < 64; ++i) + data[i] *= dequant[i]; +} + +static void stbi__jpeg_finish(stbi__jpeg * z) { + if (z->progressive) { + // dequantize and idct the data + int i, j, n; + for (n = 0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short * data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + i * 8, z->img_comp[n].w2, data); + } + } + } + } +} + +static int stbi__process_marker(stbi__jpeg * z, int m) { + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker", "Corrupt JPEG"); + + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) + return stbi__err("bad DRI len", "Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; + + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s) - 2; + while (L > 0) { + int q = stbi__get8(z->s); + int p = q >> 4, sixteen = (p != 0); + int t = q & 15, i; + if (p != 0 && p != 1) + return stbi__err("bad DQT type", "Corrupt JPEG"); + if (t > 3) + return stbi__err("bad DQT table", "Corrupt JPEG"); + + for (i = 0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L == 0; + + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s) - 2; + while (L > 0) { + stbi_uc * v; + int sizes[16], i, n = 0; + int q = stbi__get8(z->s); + int tc = q >> 4; + int th = q & 15; + if (tc > 1 || th > 3) + return stbi__err("bad DHT header", "Corrupt JPEG"); + for (i = 0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; + } + if (n > 256) + return stbi__err("bad DHT header", "Corrupt JPEG"); // Loop over i < n would write past end of values! 
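+            // the Tc/Th byte plus the 16 code-length counts account for 17 bytes
+            // of this segment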
+ L -= 17; + if (tc == 0) { + if (!stbi__build_huffman(z->huff_dc + th, sizes)) + return 0; + v = z->huff_dc[th].values; + } else { + if (!stbi__build_huffman(z->huff_ac + th, sizes)) + return 0; + v = z->huff_ac[th].values; + } + for (i = 0; i < n; ++i) + v[i] = stbi__get8(z->s); + if (tc != 0) + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + L -= n; + } + return L == 0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len", "Corrupt JPEG"); + else + return stbi__err("bad APP len", "Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J', 'F', 'I', 'F', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A', 'd', 'o', 'b', 'e', '\0'}; + int ok = 1; + int i; + for (i = 0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform + L -= 6; + } + } + + stbi__skip(z->s, L); + return 1; + } + + return stbi__err("unknown marker", "Corrupt JPEG"); +} + +// after we see SOS +static int stbi__process_scan_header(stbi__jpeg * z) { + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) + return stbi__err("bad SOS component count", "Corrupt JPEG"); + if (Ls != 6 + 2 * z->scan_n) + return stbi__err("bad SOS len", "Corrupt JPEG"); + for (i = 0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) + return 0; // no match + z->img_comp[which].hd = q >> 4; + if (z->img_comp[which].hd > 3) + return stbi__err("bad DC huff", "Corrupt JPEG"); + z->img_comp[which].ha = q & 15; + if (z->img_comp[which].ha > 3) + return stbi__err("bad AC huff", "Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) + return stbi__err("bad SOS", "Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; +} + +static int stbi__free_jpeg_components(stbi__jpeg * z, int ncomp, int why) { + int i; + for (i = 0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; + z->img_comp[i].data = NULL; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); + z->img_comp[i].linebuf = NULL; + } + } + return why; +} + +static int stbi__process_frame_header(stbi__jpeg * z, int scan) { + stbi__context * s = z->s; + int Lf, p, i, 
q, h_max = 1, v_max = 1, c; + Lf = stbi__get16be(s); + if (Lf < 11) + return stbi__err("bad SOF len", "Corrupt JPEG"); // JPEG + p = stbi__get8(s); + if (p != 8) + return stbi__err("only 8-bit", "JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); + if (s->img_y == 0) + return stbi__err("no header height", + "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG + s->img_x = stbi__get16be(s); + if (s->img_x == 0) + return stbi__err("0 width", "Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) + return stbi__err("bad component count", "Corrupt JPEG"); + s->img_n = c; + for (i = 0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8 + 3 * s->img_n) + return stbi__err("bad SOF len", "Corrupt JPEG"); + + z->rgb = 0; + for (i = 0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = {'R', 'G', 'B'}; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); + if (!z->img_comp[i].h || z->img_comp[i].h > 4) + return stbi__err("bad H", "Corrupt JPEG"); + z->img_comp[i].v = q & 15; + if (!z->img_comp[i].v || z->img_comp[i].v > 4) + return stbi__err("bad V", "Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); + if (z->img_comp[i].tq > 3) + return stbi__err("bad TQ", "Corrupt JPEG"); + } + + if (scan != STBI__SCAN_load) + return 1; + + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) + return stbi__err("too large", "Image too large to decode"); + + for (i = 0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) + h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) + v_max = z->img_comp[i].v; + } + + // check that plane subsampling factors are integer ratios; our resamplers can't deal with fractional ratios + // and I've never seen a non-corrupted JPEG file actually use them + for (i = 0; i < s->img_n; ++i) { + if (h_max % z->img_comp[i].h != 0) + return stbi__err("bad H", "Corrupt JPEG"); + if (v_max % z->img_comp[i].v != 0) + return stbi__err("bad V", "Corrupt JPEG"); + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w - 1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h - 1) / z->img_mcu_h; + + for (i = 0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max - 1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max - 1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. 
a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) + // so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i + 1, stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = (stbi_uc *)(((size_t)z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components(z, i + 1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = (short *)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); + } + } + + return 1; +} + +// use comparisons since in some cases we handle more than one case (e.g. SOF) +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) + +#define stbi__SOF_progressive(x) ((x) == 0xc2) + +static int stbi__decode_jpeg_header(stbi__jpeg * z, int scan) { + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) + return stbi__err("no SOI", "Corrupt JPEG"); + if (scan == STBI__SCAN_type) + return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z, m)) + return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) + return stbi__err("no SOF", "Corrupt JPEG"); + m = stbi__get_marker(z); + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) + return 0; + return 1; +} + +static int stbi__skip_jpeg_junk_at_end(stbi__jpeg * j) { + // some JPEGs have junk at end, skip over it but if we find what looks + // like a valid marker, resume there + while (!stbi__at_eof(j->s)) { + int x = stbi__get8(j->s); + while (x == 255) { // might be a marker + if (stbi__at_eof(j->s)) + return STBI__MARKER_none; + x = stbi__get8(j->s); + if (x != 0x00 && x != 0xff) { + // not a stuffed zero or lead-in to another marker, looks + // like an actual marker, return it + return x; + } + // stuffed zero has x=0 now which ends the loop, meaning we go + // back to regular scan loop. + // repeated 0xff keeps trying to read the next byte of the marker. 
+ } + } + return STBI__MARKER_none; +} + +// decode image to YCbCr format +static int stbi__decode_jpeg_image(stbi__jpeg * j) { + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) + return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) + return 0; + if (!stbi__parse_entropy_coded_data(j)) + return 0; + if (j->marker == STBI__MARKER_none) { + j->marker = stbi__skip_jpeg_junk_at_end(j); + // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 + } + m = stbi__get_marker(j); + if (STBI__RESTART(m)) + m = stbi__get_marker(j); + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) + return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) + return stbi__err("bad DNL height", "Corrupt JPEG"); + m = stbi__get_marker(j); + } else { + if (!stbi__process_marker(j, m)) + return 1; + m = stbi__get_marker(j); + } + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; +} + +// static jfif-centered resampling (across block boundaries) + +typedef stbi_uc * (*resample_row_func)(stbi_uc * out, stbi_uc * in0, stbi_uc * in1, int w, int hs); + +#define stbi__div4(x) ((stbi_uc)((x) >> 2)) + +static stbi_uc * resample_row_1(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; +} + +static stbi_uc * stbi__resample_row_v_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i = 0; i < w; ++i) + out[i] = stbi__div4(3 * in_near[i] + in_far[i] + 2); + return out; +} + +static stbi_uc * stbi__resample_row_h_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { + // need to generate two samples horizontally for every one in input + int i; + stbi_uc * input = in_near; + + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } + + out[0] = input[0]; + out[1] = stbi__div4(input[0] * 3 + input[1] + 2); + for (i = 1; i < w - 1; ++i) { + int n = 3 * input[i] + 2; + out[i * 2 + 0] = stbi__div4(n + input[i - 1]); + out[i * 2 + 1] = stbi__div4(n + input[i + 1]); + } + out[i * 2 + 0] = stbi__div4(input[w - 2] * 3 + input[w - 1] + 2); + out[i * 2 + 1] = input[w - 1]; + + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); + + return out; +} + +#define stbi__div16(x) ((stbi_uc)((x) >> 4)) + +static stbi_uc * stbi__resample_row_hv_2(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { + // need to generate 2x2 samples for every one in input + int i, t0, t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + out[0] = stbi__div4(t1 + 2); + for (i = 1; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); + + STBI_NOTUSED(hs); + + return out; +} + +#if defined(STBI_SSE2) || defined(STBI_NEON) +static stbi_uc * stbi__resample_row_hv_2_simd(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { + // need to generate 2x2 samples for every one in 
input + int i = 0, t0, t1; + + if (w == 1) { + out[0] = out[1] = stbi__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w - 1) & ~7); i += 8) { +#if defined(STBI_SSE2) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *)(in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *)(in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = _mm_insert_epi16(nxt0, 3 * in_near[i + 8] + in_far[i + 8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *)(out + i * 2), outv); +#elif defined(STBI_NEON) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = vsetq_lane_s16(3 * in_near[i + 8] + in_far[i + 8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. 
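+        // so each output pair is (3*cur + prev) and (3*cur + next); the +8
+        // rounding and the >>4 descale are folded into vqrshrun_n_s16 below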
+ int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i * 2, o); +#endif + + // "previous" value for next iter + t1 = 3 * in_near[i + 7] + in_far[i + 7]; + } + + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = stbi__div16(3 * t0 + t1 + 8); + out[i * 2] = stbi__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = stbi__div4(t1 + 2); + + STBI_NOTUSED(hs); + + return out; +} +#endif + +static stbi_uc * stbi__resample_row_generic(stbi_uc * out, stbi_uc * in_near, stbi_uc * in_far, int w, int hs) { + // resample with nearest-neighbor + int i, j; + STBI_NOTUSED(in_far); + for (i = 0; i < w; ++i) + for (j = 0; j < hs; ++j) + out[i * hs + j] = in_near[i]; + return out; +} + +// this is a reduced-precision calculation of YCbCr-to-RGB introduced +// to make sure the code produces the same results in both SIMD and scalar +#define stbi__float2fixed(x) (((int)((x)*4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc * out, const stbi_uc * y, const stbi_uc * pcb, const stbi_uc * pcr, int count, + int step) { + int i; + for (i = 0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + (cr * -stbi__float2fixed(0.71414f)) + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } +} + +#if defined(STBI_SSE2) || defined(STBI_NEON) +static void stbi__YCbCr_to_RGB_simd(stbi_uc * out, stbi_uc const * y, stbi_uc const * pcb, stbi_uc const * pcr, int count, + int step) { + int i = 0; + +#ifdef STBI_SSE2 + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. 
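+        // the constants are the usual YCbCr coefficients scaled by 4096; cb/cr get
+        // unpacked into the high byte (value << 8), so _mm_mulhi_epi16 leaves each
+        // product with 4 fractional bits. y is brought to the same 4-bit scale
+        // (with a rounding bias of 8 coming from the 128 in y_bias), and the final
+        // arithmetic shift by 4 descales before packing.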
+ __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16((short)(1.40200f * 4096.0f + 0.5f)); + __m128i cr_const1 = _mm_set1_epi16(-(short)(0.71414f * 4096.0f + 0.5f)); + __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); + __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); + __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i + 7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *)(y + i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *)(pcr + i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *)(pcb + i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i *)(out + 0), o0); + _mm_storeu_si128((__m128i *)(out + 16), o1); + out += 32; + } + } +#endif + +#ifdef STBI_NEON + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. 
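+        // same fixed-point scheme as the SSE2 path, except vqdmulhq_s16 doubles the
+        // product, so cb/cr are widened with << 7 instead of << 8; rounding and the
+        // unsigned clamp both happen in vqrshrun_n_s16(..., 4) at the end.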
+ uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16((short)(1.40200f * 4096.0f + 0.5f)); + int16x8_t cr_const1 = vdupq_n_s16(-(short)(0.71414f * 4096.0f + 0.5f)); + int16x8_t cb_const0 = vdupq_n_s16(-(short)(0.34414f * 4096.0f + 0.5f)); + int16x8_t cb_const1 = vdupq_n_s16((short)(1.77200f * 4096.0f + 0.5f)); + + for (; i + 7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8 * 4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * stbi__float2fixed(1.40200f); + g = y_fixed + cr * -stbi__float2fixed(0.71414f) + ((cb * -stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { + if (r < 0) + r = 0; + else + r = 255; + } + if ((unsigned)g > 255) { + if (g < 0) + g = 0; + else + g = 255; + } + if ((unsigned)b > 255) { + if (b < 0) + b = 0; + else + b = 255; + } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } +} +#endif + +// set up the kernels +static void stbi__setup_jpeg(stbi__jpeg * j) { + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; + +#ifdef STBI_SSE2 + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } +#endif + +#ifdef STBI_NEON + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; +#endif +} + +// clean up the temporary component buffers +static void stbi__cleanup_jpeg(stbi__jpeg * j) { stbi__free_jpeg_components(j, j->s->img_n, 0); } + +typedef struct { + resample_row_func resample; + stbi_uc *line0, *line1; + int hs, vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on +} stbi__resample; + +// fast 0..255 * 0..255 => 0..255 rounded multiplication +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) { + unsigned int t = x * y + 128; + return (stbi_uc)((t + (t >> 8)) >> 8); +} + +static stbi_uc * load_jpeg_image(stbi__jpeg * z, int * out_x, int * out_y, int * comp, int req_comp) { + int n, 
decode_n, is_rgb; + z->s->img_n = 0; // make stbi__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { + stbi__cleanup_jpeg(z); + return NULL; + } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // nothing to do if no components requested; check this now to avoid + // accessing uninitialized coutput[0] later + if (decode_n <= 0) { + stbi__cleanup_jpeg(z); + return NULL; + } + + // resample and color-convert + { + int k; + unsigned int i, j; + stbi_uc * output; + stbi_uc * coutput[4] = {NULL, NULL, NULL, NULL}; + + stbi__resample res_comp[4]; + + for (k = 0; k < decode_n; ++k) { + stbi__resample * r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *)stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); + } + + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs - 1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) + r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) + r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) + r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) + r->resample = z->resample_row_hv_2_kernel; + else + r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *)stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { + stbi__cleanup_jpeg(z); + return stbi__errpuc("outofmem", "Out of memory"); + } + + // now go ahead and resample + for (j = 0; j < z->s->img_y; ++j) { + stbi_uc * out = output + n * z->s->img_x * j; + for (k = 0; k < decode_n; ++k) { + stbi__resample * r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = r->resample(z->img_comp[k].linebuf, y_bot ? r->line1 : r->line0, y_bot ? 
r->line0 : r->line1, + r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; + } + } + if (n >= 3) { + stbi_uc * y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; + } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; + } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else + for (i = 0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i = 0; i < z->s->img_x; ++i, out += 2) { + out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i = 0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } + } else { + stbi_uc * y = coutput[0]; + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + out[i] = y[i]; + else + for (i = 0; i < z->s->img_x; ++i) { + *out++ = y[i]; + *out++ = 255; + } + } + } + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) + *comp = z->s->img_n >= 3 ? 
3 : 1; // report original components, not output + return output; + } +} + +static void * stbi__jpeg_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { + unsigned char * result; + stbi__jpeg * j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + if (!j) + return stbi__errpuc("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x, y, comp, req_comp); + STBI_FREE(j); + return result; +} + +static int stbi__jpeg_test(stbi__context * s) { + int r; + stbi__jpeg * j = (stbi__jpeg *)stbi__malloc(sizeof(stbi__jpeg)); + if (!j) + return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; +} + +static int stbi__jpeg_info_raw(stbi__jpeg * j, int * x, int * y, int * comp) { + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind(j->s); + return 0; + } + if (x) + *x = j->s->img_x; + if (y) + *y = j->s->img_y; + if (comp) + *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; +} + +static int stbi__jpeg_info(stbi__context * s, int * x, int * y, int * comp) { + int result; + stbi__jpeg * j = (stbi__jpeg *)(stbi__malloc(sizeof(stbi__jpeg))); + if (!j) + return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; +} +#endif + +// public domain zlib decode v0.2 Sean Barrett 2006-11-18 +// simple implementation +// - all input must be provided in an upfront buffer +// - all output is written to a single output buffer (can malloc/realloc) +// performance +// - fast huffman + +#ifndef STBI_NO_ZLIB + +// fast-way is faster to check than jpeg huffman, but slow way is slower +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZNSYMS 288 // number of symbols in literal/length alphabet + +// zlib-style huffman encoding +// (jpegs packs from left, zlib from right, so can't share code) +typedef struct { + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[STBI__ZNSYMS]; + stbi__uint16 value[STBI__ZNSYMS]; +} stbi__zhuffman; + +stbi_inline static int stbi__bitreverse16(int n) { + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); + return n; +} + +stbi_inline static int stbi__bit_reverse(int v, int bits) { + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 
11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16 - bits); +} + +static int stbi__zbuild_huffman(stbi__zhuffman * z, const stbi_uc * sizelist, int num) { + int i, k = 0; + int code, next_code[16], sizes[17]; + + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i = 0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i = 1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i = 1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16)code; + z->firstsymbol[i] = (stbi__uint16)k; + code = (code + sizes[i]); + if (sizes[i]) + if (code - 1 >= (1 << i)) + return stbi__err("bad codelengths", "Corrupt PNG"); + z->maxcode[i] = code << (16 - i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i = 0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16)((s << 9) | i); + z->size[c] = (stbi_uc)s; + z->value[c] = (stbi__uint16)i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s], s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); + } + } + ++next_code[s]; + } + } + return 1; +} + +// zlib-from-memory implementation for PNG reading +// because PNG allows splitting the zlib stream arbitrarily, +// and it's annoying structurally to have PNG call ZLIB call PNG, +// we require PNG read all the IDATs and combine them into a single +// memory buffer + +typedef struct { + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + stbi__uint32 code_buffer; + + char * zout; + char * zout_start; + char * zout_end; + int z_expandable; + + stbi__zhuffman z_length, z_distance; +} stbi__zbuf; + +stbi_inline static int stbi__zeof(stbi__zbuf * z) { return (z->zbuffer >= z->zbuffer_end); } + +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf * z) { return stbi__zeof(z) ? 0 : *z->zbuffer++; } + +static void stbi__fill_bits(stbi__zbuf * z) { + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int)stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} + +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf * z, int n) { + unsigned int k; + if (z->num_bits < n) + stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} + +static int stbi__zhuffman_decode_slowpath(stbi__zbuf * a, stbi__zhuffman * z) { + int b, s, k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s = STBI__ZFAST_BITS + 1;; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) + return -1; // invalid code! + // code size is s, so: + b = (k >> (16 - s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= STBI__ZNSYMS) + return -1; // some data was corrupt somewhere! + if (z->size[b] != s) + return -1; // was originally an assert, but report failure instead. + a->code_buffer >>= s; + a->num_bits -= s; + return z->value[b]; +} + +stbi_inline static int stbi__zhuffman_decode(stbi__zbuf * a, stbi__zhuffman * z) { + int b, s; + if (a->num_bits < 16) { + if (stbi__zeof(a)) { + return -1; /* report error for unexpected end of data. 
*/ + } + stbi__fill_bits(a); + } + b = z->fast[a->code_buffer & STBI__ZFAST_MASK]; + if (b) { + s = b >> 9; + a->code_buffer >>= s; + a->num_bits -= s; + return b & 511; + } + return stbi__zhuffman_decode_slowpath(a, z); +} + +static int stbi__zexpand(stbi__zbuf * z, char * zout, int n) // need to make room for n bytes +{ + char * q; + unsigned int cur, limit, old_limit; + z->zout = zout; + if (!z->z_expandable) + return stbi__err("output buffer limit", "Corrupt PNG"); + cur = (unsigned int)(z->zout - z->zout_start); + limit = old_limit = (unsigned)(z->zout_end - z->zout_start); + if (UINT_MAX - cur < (unsigned)n) + return stbi__err("outofmem", "Out of memory"); + while (cur + n > limit) { + if (limit > UINT_MAX / 2) + return stbi__err("outofmem", "Out of memory"); + limit *= 2; + } + q = (char *)STBI_REALLOC_SIZED(z->zout_start, old_limit, limit); + STBI_NOTUSED(old_limit); + if (q == NULL) + return stbi__err("outofmem", "Out of memory"); + z->zout_start = q; + z->zout = q + cur; + z->zout_end = q + limit; + return 1; +} + +static const int stbi__zlength_base[31] = {3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0}; + +static const int stbi__zlength_extra[31] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, + 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0}; + +static const int stbi__zdist_base[32] = {1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, + 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, + 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0}; + +static const int stbi__zdist_extra[32] = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, + 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; + +static int stbi__parse_huffman_block(stbi__zbuf * a) { + char * zout = a->zout; + for (;;) { + int z = stbi__zhuffman_decode(a, &a->z_length); + if (z < 256) { + if (z < 0) + return stbi__err("bad huffman code", "Corrupt PNG"); // error in huffman codes + if (zout >= a->zout_end) { + if (!stbi__zexpand(a, zout, 1)) + return 0; + zout = a->zout; + } + *zout++ = (char)z; + } else { + stbi_uc * p; + int len, dist; + if (z == 256) { + a->zout = zout; + return 1; + } + if (z >= 286) + return stbi__err("bad huffman code", + "Corrupt PNG"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) + len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0 || z >= 30) + return stbi__err("bad huffman code", + "Corrupt PNG"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) + dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) + return stbi__err("bad dist", "Corrupt PNG"); + if (zout + len > a->zout_end) { + if (!stbi__zexpand(a, zout, len)) + return 0; + zout = a->zout; + } + p = (stbi_uc *)(zout - dist); + if (dist == 1) { // run of one byte; common in images. 
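+                // (the else branch below has to copy byte-by-byte: when len > dist
+                // the match overlaps the bytes being written, and replicating that
+                // overlap is exactly what LZ77 matches rely on)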
+ stbi_uc v = *p; + if (len) { + do + *zout++ = v; + while (--len); + } + } else { + if (len) { + do + *zout++ = *p++; + while (--len); + } + } + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf * a) { + static const stbi_uc length_dezigzag[19] = {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286 + 32 + 137]; // padding for maximum single op + stbi_uc codelength_sizes[19]; + int i, n; + + int hlit = stbi__zreceive(a, 5) + 257; + int hdist = stbi__zreceive(a, 5) + 1; + int hclen = stbi__zreceive(a, 4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i = 0; i < hclen; ++i) { + int s = stbi__zreceive(a, 3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc)s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) + return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc)c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a, 2) + 3; + if (n == 0) + return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n - 1]; + } else if (c == 17) { + c = stbi__zreceive(a, 3) + 3; + } else if (c == 18) { + c = stbi__zreceive(a, 7) + 11; + } else { + return stbi__err("bad codelengths", "Corrupt PNG"); + } + if (ntot - n < c) + return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes + n, fill, c); + n += c; + } + } + if (n != ntot) + return stbi__err("bad codelengths", "Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes + hlit, hdist)) + return 0; + return 1; +} + +static int stbi__parse_uncompressed_block(stbi__zbuf * a) { + stbi_uc header[4]; + int len, nlen, k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = (stbi_uc)(a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) + return stbi__err("zlib corrupt", "Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) + return stbi__err("zlib corrupt", "Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) + return stbi__err("read past buffer", "Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) + return 0; + memcpy(a->zout, a->zbuffer, len); + a->zbuffer += len; + a->zout += len; + return 1; +} + +static int stbi__parse_zlib_header(stbi__zbuf * a) { + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if ((cmf * 256 + flg) % 31 != 0) + return stbi__err("bad zlib header", "Corrupt PNG"); // zlib spec + if (flg & 32) + return stbi__err("no preset dict", "Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) + return stbi__err("bad compression", "Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... 
but who cares, we fully buffer output + return 1; +} + +static const stbi_uc stbi__zdefault_length[STBI__ZNSYMS] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8}; +static const stbi_uc stbi__zdefault_distance[32] = {5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}; +/* +Init algorithm: +{ + int i; // use <= to match clearly with spec + for (i=0; i <= 143; ++i) stbi__zdefault_length[i] = 8; + for ( ; i <= 255; ++i) stbi__zdefault_length[i] = 9; + for ( ; i <= 279; ++i) stbi__zdefault_length[i] = 7; + for ( ; i <= 287; ++i) stbi__zdefault_length[i] = 8; + + for (i=0; i <= 31; ++i) stbi__zdefault_distance[i] = 5; +} +*/ + +static int stbi__parse_zlib(stbi__zbuf * a, int parse_header) { + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) + return 0; + a->num_bits = 0; + a->code_buffer = 0; + do { + final = stbi__zreceive(a, 1); + type = stbi__zreceive(a, 2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) + return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length, stbi__zdefault_length, STBI__ZNSYMS)) + return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) + return 0; + } else { + if (!stbi__compute_huffman_codes(a)) + return 0; + } + if (!stbi__parse_huffman_block(a)) + return 0; + } + } while (!final); + return 1; +} + +static int stbi__do_zlib(stbi__zbuf * a, char * obuf, int olen, int exp, int parse_header) { + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; + + return stbi__parse_zlib(a, parse_header); +} + +STBIDEF char * stbi_zlib_decode_malloc_guesssize(const char * buffer, int len, int initial_size, int * outlen) { + stbi__zbuf a; + char * p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF char * stbi_zlib_decode_malloc(char const * buffer, int len, int * outlen) { + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); +} + +STBIDEF char * stbi_zlib_decode_malloc_guesssize_headerflag(const char * buffer, int len, int initial_size, int * outlen, + int parse_header) { + stbi__zbuf a; + char * p = (char *)stbi__malloc(initial_size); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) + 
*outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_buffer(char * obuffer, int olen, char const * ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int)(a.zout - a.zout_start); + else + return -1; +} + +STBIDEF char * stbi_zlib_decode_noheader_malloc(char const * buffer, int len, int * outlen) { + stbi__zbuf a; + char * p = (char *)stbi__malloc(16384); + if (p == NULL) + return NULL; + a.zbuffer = (stbi_uc *)buffer; + a.zbuffer_end = (stbi_uc *)buffer + len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) + *outlen = (int)(a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_noheader_buffer(char * obuffer, int olen, const char * ibuffer, int ilen) { + stbi__zbuf a; + a.zbuffer = (stbi_uc *)ibuffer; + a.zbuffer_end = (stbi_uc *)ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int)(a.zout - a.zout_start); + else + return -1; +} +#endif + +// public domain "baseline" PNG decoder v0.10 Sean Barrett 2006-11-18 +// simple implementation +// - only 8-bit samples +// - no CRC checking +// - allocates lots of intermediate memory +// - avoids problem of streaming data between subsystems +// - avoids explicit window management +// performance +// - uses stb_zlib, a PD zlib implementation with fast huffman decoding + +#ifndef STBI_NO_PNG +typedef struct { + stbi__uint32 length; + stbi__uint32 type; +} stbi__pngchunk; + +static stbi__pngchunk stbi__get_chunk_header(stbi__context * s) { + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; +} + +static int stbi__check_png_header(stbi__context * s) { + static const stbi_uc png_sig[8] = {137, 80, 78, 71, 13, 10, 26, 10}; + int i; + for (i = 0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) + return stbi__err("bad png sig", "Not a PNG"); + return 1; +} + +typedef struct { + stbi__context * s; + stbi_uc *idata, *expanded, *out; + int depth; +} stbi__png; + +enum { + STBI__F_none = 0, + STBI__F_sub = 1, + STBI__F_up = 2, + STBI__F_avg = 3, + STBI__F_paeth = 4, + // synthetic filters used for first scanline to avoid needing a dummy row of 0s + STBI__F_avg_first, + STBI__F_paeth_first +}; + +static stbi_uc first_row_filter[5] = {STBI__F_none, STBI__F_sub, STBI__F_none, STBI__F_avg_first, STBI__F_paeth_first}; + +static int stbi__paeth(int a, int b, int c) { + int p = a + b - c; + int pa = abs(p - a); + int pb = abs(p - b); + int pc = abs(p - c); + if (pa <= pb && pa <= pc) + return a; + if (pb <= pc) + return b; + return c; +} + +static const stbi_uc stbi__depth_scale_table[9] = {0, 0xff, 0x55, 0, 0x11, 0, 0, 0, 0x01}; + +// create the png data from post-deflated data +static int stbi__create_png_image_raw(stbi__png * a, stbi_uc * raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, + stbi__uint32 y, int depth, int color) { + int bytes = (depth == 16 ? 
2 : 1); + stbi__context * s = a->s; + stbi__uint32 i, j, stride = x * out_n * bytes; + stbi__uint32 img_len, img_width_bytes; + int k; + int img_n = s->img_n; // copy it into a local for later + + int output_bytes = out_n * bytes; + int filter_bytes = img_n * bytes; + int width = x; + + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n + 1); + a->out = (stbi_uc *)stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) + return stbi__err("outofmem", "Out of memory"); + + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) + return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + img_len = (img_width_bytes + 1) * y; + + // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, + // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), + // so just check for raw_len < img_len always. + if (raw_len < img_len) + return stbi__err("not enough pixels", "Corrupt PNG"); + + for (j = 0; j < y; ++j) { + stbi_uc * cur = a->out + stride * j; + stbi_uc * prior; + int filter = *raw++; + + if (filter > 4) + return stbi__err("invalid filter", "Corrupt PNG"); + + if (depth < 8) { + if (img_width_bytes > x) + return stbi__err("invalid width", "Corrupt PNG"); + cur += x * out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place + filter_bytes = 1; + width = img_width_bytes; + } + prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above + + // if first row, use special filter that doesn't sample previous row + if (j == 0) + filter = first_row_filter[filter]; + + // handle first byte explicitly + for (k = 0; k < filter_bytes; ++k) { + switch (filter) { + case STBI__F_none: + cur[k] = raw[k]; + break; + case STBI__F_sub: + cur[k] = raw[k]; + break; + case STBI__F_up: + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + break; + case STBI__F_avg: + cur[k] = STBI__BYTECAST(raw[k] + (prior[k] >> 1)); + break; + case STBI__F_paeth: + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(0, prior[k], 0)); + break; + case STBI__F_avg_first: + cur[k] = raw[k]; + break; + case STBI__F_paeth_first: + cur[k] = raw[k]; + break; + } + } + + if (depth == 8) { + if (img_n != out_n) + cur[img_n] = 255; // first pixel + raw += img_n; + cur += out_n; + prior += out_n; + } else if (depth == 16) { + if (img_n != out_n) { + cur[filter_bytes] = 255; // first pixel top byte + cur[filter_bytes + 1] = 255; // first pixel bottom byte + } + raw += filter_bytes; + cur += output_bytes; + prior += output_bytes; + } else { + raw += 1; + cur += 1; + prior += 1; + } + + // this is a little gross, so that we don't switch per-pixel or per-component + if (depth < 8 || img_n == out_n) { + int nk = (width - 1) * filter_bytes; +#define STBI__CASE(f) \ + case f: \ + for (k = 0; k < nk; ++k) + switch (filter) { + // "none" filter turns into a memcpy here; make that explicit. 
+ case STBI__F_none: + memcpy(cur, raw, nk); + break; + STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k - filter_bytes]); } + break; + STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } + break; + STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - filter_bytes]) >> 1)); } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], prior[k], prior[k - filter_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); } + break; + STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - filter_bytes], 0, 0)); } + break; + } +#undef STBI__CASE + raw += nk; + } else { + STBI_ASSERT(img_n + 1 == out_n); +#define STBI__CASE(f) \ + case f: \ + for (i = x - 1; i >= 1; --i, cur[filter_bytes] = 255, raw += filter_bytes, cur += output_bytes, prior += output_bytes) \ + for (k = 0; k < filter_bytes; ++k) + switch (filter) { + STBI__CASE(STBI__F_none) { cur[k] = raw[k]; } + break; + STBI__CASE(STBI__F_sub) { cur[k] = STBI__BYTECAST(raw[k] + cur[k - output_bytes]); } + break; + STBI__CASE(STBI__F_up) { cur[k] = STBI__BYTECAST(raw[k] + prior[k]); } + break; + STBI__CASE(STBI__F_avg) { cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k - output_bytes]) >> 1)); } + break; + STBI__CASE(STBI__F_paeth) { + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], prior[k], prior[k - output_bytes])); + } + break; + STBI__CASE(STBI__F_avg_first) { cur[k] = STBI__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); } + break; + STBI__CASE(STBI__F_paeth_first) { cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k - output_bytes], 0, 0)); } + break; + } +#undef STBI__CASE + + // the loop above sets the high byte of the pixels' alpha, but for + // 16 bit png files we also need the low byte set. we'll do that here. + if (depth == 16) { + cur = a->out + stride * j; // start at the beginning of the row again + for (i = 0; i < x; ++i, cur += output_bytes) { + cur[filter_bytes + 1] = 255; + } + } + } + } + + // we make a separate pass to expand bits to pixels; for performance, + // this could run two scanlines behind the above code, so it won't + // intefere with filtering but will still be in the cache. + if (depth < 8) { + for (j = 0; j < y; ++j) { + stbi_uc * cur = a->out + stride * j; + stbi_uc * in = a->out + stride * j + x * out_n - img_width_bytes; + // unpack 1/2/4-bit into a 8-bit buffer. allows us to keep the common 8-bit path optimal at minimal cost for + // 1/2/4-bit png guarante byte alignment, if width is not multiple of 8/4/2 we'll decode dummy trailing data that + // will be skipped in the later loop + stbi_uc scale = (color == 0) ? stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range + + // note that the final byte might overshoot and write more data than desired. + // we can allocate enough data that this never writes out of memory, but it + // could also overwrite the next scanline. can it overwrite non-empty data + // on the next scanline? yes, consider 1-pixel-wide scanlines with 1-bit-per-pixel. 
+ // so we need to explicitly clamp the final ones + + if (depth == 4) { + for (k = x * img_n; k >= 2; k -= 2, ++in) { + *cur++ = scale * ((*in >> 4)); + *cur++ = scale * ((*in) & 0x0f); + } + if (k > 0) + *cur++ = scale * ((*in >> 4)); + } else if (depth == 2) { + for (k = x * img_n; k >= 4; k -= 4, ++in) { + *cur++ = scale * ((*in >> 6)); + *cur++ = scale * ((*in >> 4) & 0x03); + *cur++ = scale * ((*in >> 2) & 0x03); + *cur++ = scale * ((*in) & 0x03); + } + if (k > 0) + *cur++ = scale * ((*in >> 6)); + if (k > 1) + *cur++ = scale * ((*in >> 4) & 0x03); + if (k > 2) + *cur++ = scale * ((*in >> 2) & 0x03); + } else if (depth == 1) { + for (k = x * img_n; k >= 8; k -= 8, ++in) { + *cur++ = scale * ((*in >> 7)); + *cur++ = scale * ((*in >> 6) & 0x01); + *cur++ = scale * ((*in >> 5) & 0x01); + *cur++ = scale * ((*in >> 4) & 0x01); + *cur++ = scale * ((*in >> 3) & 0x01); + *cur++ = scale * ((*in >> 2) & 0x01); + *cur++ = scale * ((*in >> 1) & 0x01); + *cur++ = scale * ((*in) & 0x01); + } + if (k > 0) + *cur++ = scale * ((*in >> 7)); + if (k > 1) + *cur++ = scale * ((*in >> 6) & 0x01); + if (k > 2) + *cur++ = scale * ((*in >> 5) & 0x01); + if (k > 3) + *cur++ = scale * ((*in >> 4) & 0x01); + if (k > 4) + *cur++ = scale * ((*in >> 3) & 0x01); + if (k > 5) + *cur++ = scale * ((*in >> 2) & 0x01); + if (k > 6) + *cur++ = scale * ((*in >> 1) & 0x01); + } + if (img_n != out_n) { + int q; + // insert alpha = 255 + cur = a->out + stride * j; + if (img_n == 1) { + for (q = x - 1; q >= 0; --q) { + cur[q * 2 + 1] = 255; + cur[q * 2 + 0] = cur[q]; + } + } else { + STBI_ASSERT(img_n == 3); + for (q = x - 1; q >= 0; --q) { + cur[q * 4 + 3] = 255; + cur[q * 4 + 2] = cur[q * 3 + 2]; + cur[q * 4 + 1] = cur[q * 3 + 1]; + cur[q * 4 + 0] = cur[q * 3 + 0]; + } + } + } + } + } else if (depth == 16) { + // force the image data from big-endian to platform-native. + // this is done in a separate pass due to the decoding relying + // on the data being untouched, but could probably be done + // per-line during decode if care is taken. + stbi_uc * cur = a->out; + stbi__uint16 * cur16 = (stbi__uint16 *)cur; + + for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) { + *cur16 = (cur[0] << 8) | cur[1]; + } + } + + return 1; +} + +static int stbi__create_png_image(stbi__png * a, stbi_uc * image_data, stbi__uint32 image_data_len, int out_n, int depth, + int color, int interlaced) { + int bytes = (depth == 16 ? 
2 : 1); + int out_bytes = out_n * bytes; + stbi_uc * final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); + + // de-interlacing + final = (stbi_uc *)stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + if (!final) + return stbi__err("outofmem", "Out of memory"); + for (p = 0; p < 7; ++p) { + int xorig[] = {0, 4, 0, 2, 0, 1, 0}; + int yorig[] = {0, 0, 4, 0, 2, 0, 1}; + int xspc[] = {8, 8, 4, 4, 2, 2, 1}; + int yspc[] = {8, 8, 8, 4, 4, 2, 2}; + int i, j, x, y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p] - 1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p] - 1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j = 0; j < y; ++j) { + for (i = 0; i < x; ++i) { + int out_y = j * yspc[p] + yorig[p]; + int out_x = i * xspc[p] + xorig[p]; + memcpy(final + out_y * a->s->img_x * out_bytes + out_x * out_bytes, a->out + (j * x + i) * out_bytes, + out_bytes); + } + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; + + return 1; +} + +static int stbi__compute_transparency(stbi__png * z, stbi_uc tc[3], int out_n) { + stbi__context * s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc * p = z->out; + + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); + + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 255); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; +} + +static int stbi__compute_transparency16(stbi__png * z, stbi__uint16 tc[3], int out_n) { + stbi__context * s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 * p = (stbi__uint16 *)z->out; + + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); + + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 
0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; +} + +static int stbi__expand_png_palette(stbi__png * a, stbi_uc * palette, int len, int pal_img_n) { + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; + + p = (stbi_uc *)stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); + + // between here and free(out) below, exitting would leak + temp_out = p; + + if (pal_img_n == 3) { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p += 3; + } + } else { + for (i = 0; i < pixel_count; ++i) { + int n = orig[i] * 4; + p[0] = palette[n]; + p[1] = palette[n + 1]; + p[2] = palette[n + 2]; + p[3] = palette[n + 3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; + + STBI_NOTUSED(len); + + return 1; +} + +static int stbi__unpremultiply_on_load_global = 0; +static int stbi__de_iphone_flag_global = 0; + +STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) { + stbi__unpremultiply_on_load_global = flag_true_if_should_unpremultiply; +} + +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) { + stbi__de_iphone_flag_global = flag_true_if_should_convert; +} + +#ifndef STBI_THREAD_LOCAL +#define stbi__unpremultiply_on_load stbi__unpremultiply_on_load_global +#define stbi__de_iphone_flag stbi__de_iphone_flag_global +#else +static STBI_THREAD_LOCAL int stbi__unpremultiply_on_load_local, stbi__unpremultiply_on_load_set; +static STBI_THREAD_LOCAL int stbi__de_iphone_flag_local, stbi__de_iphone_flag_set; + +STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply) { + stbi__unpremultiply_on_load_local = flag_true_if_should_unpremultiply; + stbi__unpremultiply_on_load_set = 1; +} + +STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert) { + stbi__de_iphone_flag_local = flag_true_if_should_convert; + stbi__de_iphone_flag_set = 1; +} + +#define stbi__unpremultiply_on_load \ + (stbi__unpremultiply_on_load_set ? stbi__unpremultiply_on_load_local : stbi__unpremultiply_on_load_global) +#define stbi__de_iphone_flag (stbi__de_iphone_flag_set ? 
stbi__de_iphone_flag_local : stbi__de_iphone_flag_global) +#endif // STBI_THREAD_LOCAL + +static void stbi__de_iphone(stbi__png * z) { + stbi__context * s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc * p = z->out; + + if (s->img_out_n == 3) { // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i = 0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = (t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; + } + } else { + // convert bgr to rgb + for (i = 0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 4; + } + } + } +} + +#define STBI__PNG_TYPE(a, b, c, d) (((unsigned)(a) << 24) + ((unsigned)(b) << 16) + ((unsigned)(c) << 8) + (unsigned)(d)) + +static int stbi__parse_png_file(stbi__png * z, int scan, int req_comp) { + stbi_uc palette[1024], pal_img_n = 0; + stbi_uc has_trans = 0, tc[3] = {0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff = 0, idata_limit = 0, i, pal_len = 0; + int first = 1, k, interlace = 0, color = 0, is_iphone = 0; + stbi__context * s = z->s; + + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; + + if (!stbi__check_png_header(s)) + return 0; + + if (scan == STBI__SCAN_type) + return 1; + + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C', 'g', 'B', 'I'): + is_iphone = 1; + stbi__skip(s, c.length); + break; + case STBI__PNG_TYPE('I', 'H', 'D', 'R'): { + int comp, filter; + if (!first) + return stbi__err("multiple IHDR", "Corrupt PNG"); + first = 0; + if (c.length != 13) + return stbi__err("bad IHDR len", "Corrupt PNG"); + s->img_x = stbi__get32be(s); + s->img_y = stbi__get32be(s); + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + z->depth = stbi__get8(s); + if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) + return stbi__err("1/2/4/8/16-bit only", "PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); + if (color > 6) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3 && z->depth == 16) + return stbi__err("bad ctype", "Corrupt PNG"); + if (color == 3) + pal_img_n = 3; + else if (color & 1) + return stbi__err("bad ctype", "Corrupt PNG"); + comp = stbi__get8(s); + if (comp) + return stbi__err("bad comp method", "Corrupt PNG"); + filter = stbi__get8(s); + if (filter) + return stbi__err("bad filter method", "Corrupt PNG"); + interlace = stbi__get8(s); + if (interlace > 1) + return stbi__err("bad interlace method", "Corrupt PNG"); + if (!s->img_x || !s->img_y) + return stbi__err("0-pixel image", "Corrupt PNG"); + if (!pal_img_n) { + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) + return stbi__err("too large", "Image too large to decode"); + } else { + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. 
+ s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) + return stbi__err("too large", "Corrupt PNG"); + } + // even with SCAN_header, have to scan to see if we have a tRNS + break; + } + + case STBI__PNG_TYPE('P', 'L', 'T', 'E'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256 * 3) + return stbi__err("invalid PLTE", "Corrupt PNG"); + pal_len = c.length / 3; + if (pal_len * 3 != c.length) + return stbi__err("invalid PLTE", "Corrupt PNG"); + for (i = 0; i < pal_len; ++i) { + palette[i * 4 + 0] = stbi__get8(s); + palette[i * 4 + 1] = stbi__get8(s); + palette[i * 4 + 2] = stbi__get8(s); + palette[i * 4 + 3] = 255; + } + break; + } + + case STBI__PNG_TYPE('t', 'R', 'N', 'S'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) + return stbi__err("tRNS after IDAT", "Corrupt PNG"); + if (pal_img_n) { + if (scan == STBI__SCAN_header) { + s->img_n = 4; + return 1; + } + if (pal_len == 0) + return stbi__err("tRNS before PLTE", "Corrupt PNG"); + if (c.length > pal_len) + return stbi__err("bad tRNS len", "Corrupt PNG"); + pal_img_n = 4; + for (i = 0; i < c.length; ++i) + palette[i * 4 + 3] = stbi__get8(s); + } else { + if (!(s->img_n & 1)) + return stbi__err("tRNS with alpha", "Corrupt PNG"); + if (c.length != (stbi__uint32)s->img_n * 2) + return stbi__err("bad tRNS len", "Corrupt PNG"); + has_trans = 1; + // non-paletted with tRNS = constant alpha. if header-scanning, we can stop now. + if (scan == STBI__SCAN_header) { + ++s->img_n; + return 1; + } + if (z->depth == 16) { + for (k = 0; k < s->img_n; ++k) + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n; ++k) + tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * + stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger + } + } + break; + } + + case STBI__PNG_TYPE('I', 'D', 'A', 'T'): { + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) + return stbi__err("no PLTE", "Corrupt PNG"); + if (scan == STBI__SCAN_header) { + // header scan definitely stops at first IDAT + if (pal_img_n) + s->img_n = pal_img_n; + return 1; + } + if (c.length > (1u << 30)) + return stbi__err("IDAT size limit", "IDAT section larger than 2^30 bytes"); + if ((int)(ioff + c.length) < (int)ioff) + return 0; + if (ioff + c.length > idata_limit) { + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc * p; + if (idata_limit == 0) + idata_limit = c.length > 4096 ? 
c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *)STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); + if (p == NULL) + return stbi__err("outofmem", "Out of memory"); + z->idata = p; + } + if (!stbi__getn(s, z->idata + ioff, c.length)) + return stbi__err("outofdata", "Corrupt PNG"); + ioff += c.length; + break; + } + + case STBI__PNG_TYPE('I', 'E', 'N', 'D'): { + stbi__uint32 raw_len, bpl; + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) + return 1; + if (z->idata == NULL) + return stbi__err("no IDAT", "Corrupt PNG"); + // initial guess for decoded data size to avoid unnecessary reallocs + bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component + raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; + z->expanded = (stbi_uc *)stbi_zlib_decode_malloc_guesssize_headerflag((char *)z->idata, ioff, raw_len, + (int *)&raw_len, !is_iphone); + if (z->expanded == NULL) + return 0; // zlib should set error + STBI_FREE(z->idata); + z->idata = NULL; + if ((req_comp == s->img_n + 1 && req_comp != 3 && !pal_img_n) || has_trans) + s->img_out_n = s->img_n + 1; + else + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) + return 0; + if (has_trans) { + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) + return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) + return 0; + } + } + if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) + stbi__de_iphone(z); + if (pal_img_n) { + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) + s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; + } else if (has_trans) { + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; + } + STBI_FREE(z->expanded); + z->expanded = NULL; + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + return 1; + } + + default: + // if critical, fail + if (first) + return stbi__err("first not IHDR", "Corrupt PNG"); + if ((c.type & (1 << 29)) == 0) { +#ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); +#endif + return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); + } + stbi__skip(s, c.length); + break; + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } +} + +static void * stbi__do_png(stbi__png * p, int * x, int * y, int * n, int req_comp, stbi__result_info * ri) { + void * result = NULL; + if (req_comp < 0 || req_comp > 4) + return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *)result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + else + 
result = stbi__convert_format16((stbi__uint16 *)result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) + return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) + *n = p->s->img_n; + } + STBI_FREE(p->out); + p->out = NULL; + STBI_FREE(p->expanded); + p->expanded = NULL; + STBI_FREE(p->idata); + p->idata = NULL; + + return result; +} + +static void * stbi__png_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { + stbi__png p; + p.s = s; + return stbi__do_png(&p, x, y, comp, req_comp, ri); +} + +static int stbi__png_test(stbi__context * s) { + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; +} + +static int stbi__png_info_raw(stbi__png * p, int * x, int * y, int * comp) { + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind(p->s); + return 0; + } + if (x) + *x = p->s->img_x; + if (y) + *y = p->s->img_y; + if (comp) + *comp = p->s->img_n; + return 1; +} + +static int stbi__png_info(stbi__context * s, int * x, int * y, int * comp) { + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); +} + +static int stbi__png_is16(stbi__context * s) { + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; +} +#endif + +// Microsoft/Windows BMP image + +#ifndef STBI_NO_BMP +static int stbi__bmp_test_raw(stbi__context * s) { + int r; + int sz; + if (stbi__get8(s) != 'B') + return 0; + if (stbi__get8(s) != 'M') + return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; +} + +static int stbi__bmp_test(stbi__context * s) { + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; +} + +// returns 0..31 for the highest set bit +static int stbi__high_bit(unsigned int z) { + int n = 0; + if (z == 0) + return -1; + if (z >= 0x10000) { + n += 16; + z >>= 16; + } + if (z >= 0x00100) { + n += 8; + z >>= 8; + } + if (z >= 0x00010) { + n += 4; + z >>= 4; + } + if (z >= 0x00004) { + n += 2; + z >>= 2; + } + if (z >= 0x00002) { + n += 1; /* >>= 1;*/ + } + return n; +} + +static int stbi__bitcount(unsigned int a) { + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; +} + +// extract an arbitrarily-aligned N-bit value (N=bits) +// from v, and then make it 8-bits long and fractionally +// extend it to full full range. 
+static int stbi__shiftsigned(unsigned int v, int shift, int bits) { + static unsigned int mul_table[9] = { + 0, + 0xff /*0b11111111*/, + 0x55 /*0b01010101*/, + 0x49 /*0b01001001*/, + 0x11 /*0b00010001*/, + 0x21 /*0b00100001*/, + 0x41 /*0b01000001*/, + 0x81 /*0b10000001*/, + 0x01 /*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0, 0, 1, 0, 2, 4, 6, 0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8 - bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int)((unsigned)v * mul_table[bits]) >> shift_table[bits]; +} + +typedef struct { + int bpp, offset, hsz; + unsigned int mr, mg, mb, ma, all_a; + int extra_read; +} stbi__bmp_data; + +static int stbi__bmp_set_mask_defaults(stbi__bmp_data * info, int compress) { + // BI_BITFIELDS specifies masks explicitly, don't override + if (compress == 3) + return 1; + + if (compress == 0) { + if (info->bpp == 16) { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } else if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 + } else { + // otherwise, use defaults, which is all-0 + info->mr = info->mg = info->mb = info->ma = 0; + } + return 1; + } + return 0; // error +} + +static void * stbi__bmp_parse_header(stbi__context * s, stbi__bmp_data * info) { + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') + return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; + + if (info->offset < 0) + return stbi__errpuc("bad BMP", "bad BMP"); + + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) + return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) + return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) + return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + if (compress >= 4) + return stbi__errpuc("BMP JPEG/PNG", + "BMP type not supported: unsupported compression"); // this includes PNG/JPEG modes + if (compress == 3 && info->bpp != 16 && info->bpp != 32) + return stbi__errpuc("bad BMP", "bad BMP"); // bitfields requires 16 or 32 bits/pixel + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + stbi__bmp_set_mask_defaults(info, compress); + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? 
+ return stbi__errpuc("bad BMP", "bad BMP"); + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); + } + } else { + // V4/V5 header + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + if (compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs + stbi__bmp_set_mask_defaults(info, compress); + stbi__get32le(s); // discard color space + for (i = 0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved + } + } + } + return (void *)1; +} + +static void * stbi__bmp_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { + stbi_uc * out; + unsigned int mr = 0, mg = 0, mb = 0, ma = 0, all_a; + stbi_uc pal[256][4]; + int psize = 0, i, j, width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); + + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set + + flip_vertically = ((int)s->img_y) > 0; + s->img_y = abs((int)s->img_y); + + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; + + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + // accept some number of extra bytes after the header, but if the offset points either to before + // the header ends or implies a large amount of extra data, reject the file as malformed + int bytes_read_so_far = s->callback_already_read + (int)(s->img_buffer - s->img_buffer_original); + int header_limit = 1024; // max we actually read is below 256 bytes currently. + int extra_data_limit = 256 * 4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size. + if (bytes_read_so_far <= 0 || bytes_read_so_far > header_limit) { + return stbi__errpuc("bad header", "Corrupt BMP"); + } + // we established that bytes_read_so_far is positive and sensible. + // the first half of this test rejects offsets that are either too small positives, or + // negative, and guarantees that info.offset >= bytes_read_so_far > 0. this in turn + // ensures the number computed in the second half of the test can't overflow. + if (info.offset < bytes_read_so_far || info.offset - bytes_read_so_far > extra_data_limit) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + stbi__skip(s, info.offset - bytes_read_so_far); + } + } + + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 
4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert + + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *)stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z = 0; + if (psize == 0 || psize > 256) { + STBI_FREE(out); + return stbi__errpuc("invalid", "Corrupt BMP"); + } + for (i = 0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) + stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) + width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) + width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) + width = s->img_x; + else { + STBI_FREE(out); + return stbi__errpuc("bad bpp", "Corrupt BMP"); + } + pad = (-width) & 3; + if (info.bpp == 1) { + for (j = 0; j < (int)s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i = 0; i < (int)s->img_x; ++i) { + int color = (v >> bit_offset) & 0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + if ((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } + } + stbi__skip(s, pad); + } + } else { + for (j = 0; j < (int)s->img_y; ++j) { + for (i = 0; i < (int)s->img_x; i += 2) { + int v = stbi__get8(s), v2 = 0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + if (i + 1 == (int)s->img_x) + break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) + out[z++] = 255; + } + stbi__skip(s, pad); + } + } + } else { + int rshift = 0, gshift = 0, bshift = 0, ashift = 0, rcount = 0, gcount = 0, bcount = 0, acount = 0; + int z = 0; + int easy = 0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) + width = 3 * s->img_x; + else if (info.bpp == 16) + width = 2 * s->img_x; + else /* bpp = 32 and pad = 0 */ + width = 0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); + } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr) - 7; + rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg) - 7; + gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb) - 7; + bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma) - 7; + acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { + STBI_FREE(out); + return stbi__errpuc("bad masks", "Corrupt BMP"); + } + } + for (j = 0; j < (int)s->img_y; ++j) { + if (easy) { + for (i = 0; i < (int)s->img_x; ++i) { + unsigned char a; + out[z + 2] = stbi__get8(s); + out[z + 1] = stbi__get8(s); + out[z + 0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? 
stbi__get8(s) : 255); + all_a |= a; + if (target == 4) + out[z++] = a; + } + } else { + int bpp = info.bpp; + for (i = 0; i < (int)s->img_x; ++i) { + stbi__uint32 v = (bpp == 16 ? (stbi__uint32)stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) + out[z++] = STBI__BYTECAST(a); + } + } + stbi__skip(s, pad); + } + } + + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i = 4 * s->img_x * s->img_y - 1; i >= 0; i -= 4) + out[i] = 255; + + if (flip_vertically) { + stbi_uc t; + for (j = 0; j < (int)s->img_y >> 1; ++j) { + stbi_uc * p1 = out + j * s->img_x * target; + stbi_uc * p2 = out + (s->img_y - 1 - j) * s->img_x * target; + for (i = 0; i < (int)s->img_x * target; ++i) { + t = p1[i]; + p1[i] = p2[i]; + p2[i] = t; + } + } + } + + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; + return out; +} +#endif + +// Targa Truevision - TGA +// by Jonathan Dummer +#ifndef STBI_NO_TGA +// returns STBI_rgb or whatever, 0 on error +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int * is_rgb16) { + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) + *is_rgb16 = 0; + switch (bits_per_pixel) { + case 8: + return STBI_grey; + case 16: + if (is_grey) + return STBI_grey_alpha; + // fallthrough + case 15: + if (is_rgb16) + *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: + return bits_per_pixel / 8; + default: + return 0; + } +} + +static int stbi__tga_info(stbi__context * s, int * x, int * y, int * comp) { + int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; + int sz, tga_colormap_type; + stbi__get8(s); // discard Offset + tga_colormap_type = stbi__get8(s); // colormap type + if (tga_colormap_type > 1) { + stbi__rewind(s); + return 0; // only RGB or indexed allowed + } + tga_image_type = stbi__get8(s); // image type + if (tga_colormap_type == 1) { // colormapped (paletted) image + if (tga_image_type != 1 && tga_image_type != 9) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 4); // skip image x and y origin + tga_colormap_bpp = sz; + } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE + if ((tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11)) { + stbi__rewind(s); + return 0; // only RGB or grey allowed, +/- RLE + } + stbi__skip(s, 9); // skip colormap specification and image x/y origin + tga_colormap_bpp = 0; + } + tga_w = stbi__get16le(s); + if (tga_w < 1) { + stbi__rewind(s); + return 0; // test width + } + tga_h = stbi__get16le(s); + if (tga_h < 1) { + stbi__rewind(s); + return 0; // test height + } + tga_bits_per_pixel = stbi__get8(s); // bits per pixel + stbi__get8(s); // ignore alpha bits + if (tga_colormap_bpp != 0) { + if 
((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + // when using a colormap, tga_bits_per_pixel is the size of the indexes + // I don't think anything but 8 or 16bit indexes makes sense + stbi__rewind(s); + return 0; + } + tga_comp = stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); + } else { + tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); + } + if (!tga_comp) { + stbi__rewind(s); + return 0; + } + if (x) + *x = tga_w; + if (y) + *y = tga_h; + if (comp) + *comp = tga_comp; + return 1; // seems to have passed everything +} + +static int stbi__tga_test(stbi__context * s) { + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if (tga_color_type > 1) + goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if (tga_color_type == 1) { // colormapped (paletted) image + if (sz != 1 && sz != 9) + goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s, 4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + stbi__skip(s, 4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ((sz != 2) && (sz != 3) && (sz != 10) && (sz != 11)) + goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s, 9); // skip colormap specification and image x/y origin + } + if (stbi__get16le(s) < 1) + goto errorEnd; // test width + if (stbi__get16le(s) < 1) + goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ((tga_color_type == 1) && (sz != 8) && (sz != 16)) + goto errorEnd; // for colormapped images, bpp is size of an index + if ((sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32)) + goto errorEnd; + + res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + +errorEnd: + stbi__rewind(s); + return res; +} + +// read 16bit value and convert to 24bit RGB +static void stbi__tga_read_rgb16(stbi__context * s, stbi_uc * out) { + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later + out[0] = (stbi_uc)((r * 255) / 31); + out[1] = (stbi_uc)((g * 255) / 31); + out[2] = (stbi_uc)((b * 255) / 31); + + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. 
+} + +static void * stbi__tga_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16 = 0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) + // image data + unsigned char * tga_data; + unsigned char * tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO + + if (tga_height > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // do a tiny bit of precessing + if (tga_image_type >= 8) { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); + + // If I'm paletted, then I'll use the number of bits from the palette + if (tga_indexed) + tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else + tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); + + if (!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + + // tga info + *x = tga_width; + *y = tga_height; + if (comp) + *comp = tga_comp; + + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); + + tga_data = (unsigned char *)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) + return stbi__errpuc("outofmem", "Out of memory"); + + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset); + + if (!tga_indexed && !tga_is_RLE && !tga_rgb16) { + for (i = 0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height - i - 1 : i; + stbi_uc * tga_row = tga_data + row * tga_width * tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if (tga_indexed) { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + + // any data to skip? 
(offset usually = 0) + stbi__skip(s, tga_palette_start); + // load the palette + tga_palette = (unsigned char *)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); + } + if (tga_rgb16) { + stbi_uc * pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i = 0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; + } + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + } + // load the data + for (i = 0; i < tga_width * tga_height; ++i) { + // if I'm in RLE mode, do I need to get a RLE stbi__pngchunk? + if (tga_is_RLE) { + if (RLE_count == 0) { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if (!RLE_repeating) { + read_next_pixel = 1; + } + } else { + read_next_pixel = 1; + } + // OK, if I need to read a pixel, do it now + if (read_next_pixel) { + // load however much data we did have + if (tga_indexed) { + // read in index, then perform the lookup + int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s); + if (pal_idx >= tga_palette_len) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx + j]; + } + } else if (tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); + } else { + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } + } + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel + + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i * tga_comp + j] = raw_data[j]; + + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if (tga_inverted) { + for (j = 0; j * 2 < tga_height; ++j) { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; + } + } + } + // clear my palette, if I had one + if (tga_palette != NULL) { + STBI_FREE(tga_palette); + } + } + + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) { + unsigned char * tga_pixel = tga_data; + for (i = 0; i < tga_width * tga_height; ++i) { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } + + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... 
[8^( + tga_palette_start = tga_palette_len = tga_palette_bits = tga_x_origin = tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; +} +#endif + +// ************************************************************************************************* +// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB + +#ifndef STBI_NO_PSD +static int stbi__psd_test(stbi__context * s) { + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; +} + +static int stbi__psd_decode_rle(stbi__context * s, stbi_uc * p, int pixelCount) { + int count, nleft, len; + + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) + return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; + } + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) + return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } + + return 1; +} + +static void * stbi__psd_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri, int bpc) { + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w, h; + stbi_uc * out; + STBI_NOTUSED(ri); + + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); + + // Check file type version. + if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + + // Skip 6 reserved bytes. + stbi__skip(s, 6); + + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); + + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); + + if (h > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); + + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + + // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) + stbi__skip(s, stbi__get32be(s)); + + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s)); + + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s)); + + // Find out if the data is compressed. 
+ // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", "PSD has an unknown compression format"); + + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); + + // Create the destination image. + + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *)stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *)stbi__malloc(4 * w * h); + + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w * h; + + // Initialize the data to zero. + // memset( out, 0, pixelCount * 4 ); + + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. + // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. + // Else if n is 128, noop. + // Endloop + + // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, + // which we're going to just skip. + stbi__skip(s, h * channelCount * 2); + + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc * p; + + p = out + channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); + } + } + } + } else { + // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) + // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. + + // Read the data by channel. + for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 * q = ((stbi__uint16 *)out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; + } else { + stbi_uc * p = out + channel; + stbi_uc val = channel == 3 ? 
255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; + } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 * q = ((stbi__uint16 *)out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16)stbi__get16be(s); + } else { + stbi_uc * p = out + channel; + if (bitdepth == 16) { // input bpc + for (i = 0; i < pixelCount; i++, p += 4) + *p = (stbi_uc)(stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } + } + } + } + } + + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i = 0; i < w * h; ++i) { + stbi__uint16 * pixel = (stbi__uint16 *)out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16)(pixel[0] * ra + inv_a); + pixel[1] = (stbi__uint16)(pixel[1] * ra + inv_a); + pixel[2] = (stbi__uint16)(pixel[2] * ra + inv_a); + } + } + } else { + for (i = 0; i < w * h; ++i) { + unsigned char * pixel = out + 4 * i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char)(pixel[0] * ra + inv_a); + pixel[1] = (unsigned char)(pixel[1] * ra + inv_a); + pixel[2] = (unsigned char)(pixel[2] * ra + inv_a); + } + } + } + } + + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, 4, req_comp, w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + + if (comp) + *comp = 4; + *y = h; + *x = w; + + return out; +} +#endif + +// ************************************************************************************************* +// Softimage PIC loader +// by Tom Seddon +// +// See http://softimage.wiki.softimage.com/index.php/INFO:_PIC_file_format +// See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ + +#ifndef STBI_NO_PIC +static int stbi__pic_is4(stbi__context * s, const char * str) { + int i; + for (i = 0; i < 4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; + + return 1; +} + +static int stbi__pic_test_core(stbi__context * s) { + int i; + + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) + return 0; + + for (i = 0; i < 84; ++i) + stbi__get8(s); + + if (!stbi__pic_is4(s, "PICT")) + return 0; + + return 1; +} + +typedef struct { + stbi_uc size, type, channel; +} stbi__pic_packet; + +static stbi_uc * stbi__readval(stbi__context * s, int channel, stbi_uc * dest) { + int mask = 0x80, i; + + for (i = 0; i < 4; ++i, mask >>= 1) { + if (channel & mask) { + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "PIC file too short"); + dest[i] = stbi__get8(s); + } + } + + return dest; +} + +static void stbi__copyval(int channel, stbi_uc * dest, const stbi_uc * src) { + int mask = 0x80, i; + + for (i = 0; i < 4; ++i, mask >>= 1) + if (channel & mask) + dest[i] = src[i]; +} + +static stbi_uc * stbi__pic_load_core(stbi__context * s, int width, int height, int * comp, stbi_uc * result) { + int act_comp = 0, num_packets = 0, y, chained; + stbi__pic_packet packets[10]; + + // this will (should...) cater for even some bizarre stuff like having data + // for the same channel in multiple packets. 
+ do { + stbi__pic_packet * packet; + + if (num_packets == sizeof(packets) / sizeof(packets[0])) + return stbi__errpuc("bad format", "too many packets"); + + packet = &packets[num_packets++]; + + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + + act_comp |= packet->channel; + + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (reading packets)"); + if (packet->size != 8) + return stbi__errpuc("bad format", "packet isn't 8bpp"); + } while (chained); + + *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel? + + for (y = 0; y < height; ++y) { + int packet_idx; + + for (packet_idx = 0; packet_idx < num_packets; ++packet_idx) { + stbi__pic_packet * packet = &packets[packet_idx]; + stbi_uc * dest = result + y * width * 4; + + switch (packet->type) { + default: + return stbi__errpuc("bad format", "packet has bad compression type"); + + case 0: { // uncompressed + int x; + + for (x = 0; x < width; ++x, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + break; + } + + case 1: // Pure RLE + { + int left = width, i; + + while (left > 0) { + stbi_uc count, value[4]; + + count = stbi__get8(s); + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pure read count)"); + + if (count > left) + count = (stbi_uc)left; + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + left -= count; + } + } break; + + case 2: { // Mixed RLE + int left = width; + while (left > 0) { + int count = stbi__get8(s), i; + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (mixed read count)"); + + if (count >= 128) { // Repeated + stbi_uc value[4]; + + if (count == 128) + count = stbi__get16be(s); + else + count -= 127; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + if (!stbi__readval(s, packet->channel, value)) + return 0; + + for (i = 0; i < count; ++i, dest += 4) + stbi__copyval(packet->channel, dest, value); + } else { // Raw + ++count; + if (count > left) + return stbi__errpuc("bad file", "scanline overrun"); + + for (i = 0; i < count; ++i, dest += 4) + if (!stbi__readval(s, packet->channel, dest)) + return 0; + } + left -= count; + } + break; + } + } + } + } + + return result; +} + +static void * stbi__pic_load(stbi__context * s, int * px, int * py, int * comp, int req_comp, stbi__result_info * ri) { + stbi_uc * result; + int i, x, y, internal_comp; + STBI_NOTUSED(ri); + + if (!comp) + comp = &internal_comp; + + for (i = 0; i < 92; ++i) + stbi__get8(s); + + x = stbi__get16be(s); + y = stbi__get16be(s); + + if (y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + if (stbi__at_eof(s)) + return stbi__errpuc("bad file", "file too short (pic header)"); + if (!stbi__mad3sizes_valid(x, y, 4, 0)) + return stbi__errpuc("too large", "PIC image too large to decode"); + + stbi__get32be(s); // skip `ratio' + stbi__get16be(s); // skip `fields' + stbi__get16be(s); // skip `pad' + + // intermediate buffer is RGBA + result = (stbi_uc *)stbi__malloc_mad3(x, y, 4, 0); + if (!result) + return stbi__errpuc("outofmem", "Out of memory"); + memset(result, 0xff, x * y * 4); + + if (!stbi__pic_load_core(s, x, y, comp, result)) { + STBI_FREE(result); + result = 0; + } + *px = x; + *py = y; + if (req_comp == 0) + 
req_comp = *comp; + result = stbi__convert_format(result, 4, req_comp, x, y); + + return result; +} + +static int stbi__pic_test(stbi__context * s) { + int r = stbi__pic_test_core(s); + stbi__rewind(s); + return r; +} +#endif + +// ************************************************************************************************* +// GIF loader -- public domain by Jean-Marc Lienher -- simplified/shrunk by stb + +#ifndef STBI_NO_GIF +typedef struct { + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; +} stbi__gif_lzw; + +typedef struct { + int w, h; + stbi_uc * out; // output buffer (always 4 components) + stbi_uc * background; // The current "background" as far as a gif is concerned + stbi_uc * history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc * color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; +} stbi__gif; + +static int stbi__gif_test_raw(stbi__context * s) { + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') + return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') + return 0; + if (stbi__get8(s) != 'a') + return 0; + return 1; +} + +static int stbi__gif_test(stbi__context * s) { + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; +} + +static void stbi__gif_parse_colortable(stbi__context * s, stbi_uc pal[256][4], int num_entries, int transp) { + int i; + for (i = 0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } +} + +static int stbi__gif_header(stbi__context * s, stbi__gif * g, int * comp, int is_info) { + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); + + version = stbi__get8(s); + if (version != '7' && version != '9') + return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') + return stbi__err("not GIF", "Corrupt GIF"); + + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; + + if (g->w > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) + return stbi__err("too large", "Very large image (corrupt?)"); + + if (comp != 0) + *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + + if (is_info) + return 1; + + if (g->flags & 0x80) + stbi__gif_parse_colortable(s, g->pal, 2 << (g->flags & 7), -1); + + return 1; +} + +static int stbi__gif_info_raw(stbi__context * s, int * x, int * y, int * comp) { + stbi__gif * g = (stbi__gif *)stbi__malloc(sizeof(stbi__gif)); + if (!g) + return stbi__err("outofmem", "Out of memory"); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind(s); + return 0; + } + if (x) + *x = g->w; + if (y) + *y = g->h; + STBI_FREE(g); + return 1; +} + +static void stbi__out_gif_code(stbi__gif * g, stbi__uint16 code) { + stbi_uc *p, *c; + int idx; + + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); + + if (g->cur_y >= g->max_y) + return; + + idx = g->cur_x + 
g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; + + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; + + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; + + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } +} + +static stbi_uc * stbi__process_gif_raster(stbi__context * s, stbi__gif * g) { + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw * p; + + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) + return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc)init_code; + g->codes[init_code].suffix = (stbi_uc)init_code; + } + + // support no starting clear code + avail = clear + 2; + oldcode = -1; + + len = 0; + for (;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32)stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s, len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); + } + + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } + + p->prefix = (stbi__int16)oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? 
p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + + stbi__out_gif_code(g, (stbi__uint16)code); + + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; + } + + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } +} + +// this function is designed to support animated gifs, although stb_image doesn't support it +// two back is the image from two frames ago, used for a very specific disposal format +static stbi_uc * stbi__gif_load_next(stbi__context * s, stbi__gif * g, int * comp, int req_comp, stbi_uc * two_back) { + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); + + // on first frame, any non-written pixels get the background colour (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp, 0)) + return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *)stbi__malloc(4 * pcount); + g->background = (stbi_uc *)stbi__malloc(4 * pcount); + g->history = (stbi_uc *)stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); + + // image is treated as "transparent" at the start - ie, nothing overwrites the current background; + // background colour is only used for pixels that are not rendered first frame, after that "background" + // color refers to the color that was there the previous frame. + memset(g->out, 0x00, 4 * pcount); + memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent) + memset(g->history, 0x00, pcount); // pixels that were affected previous frame + first_frame = 1; + } else { + // second frame - how do we dispose of the previous one? + dispose = (g->eflags & 0x1C) >> 2; + pcount = g->w * g->h; + + if ((dispose == 3) && (two_back == 0)) { + dispose = 2; // if I don't have an image to revert back to, default to the old background + } + + if (dispose == 3) { // use previous graphic + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &two_back[pi * 4], 4); + } + } + } else if (dispose == 2) { + // restore what was changed last frame to background before that frame; + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi]) { + memcpy(&g->out[pi * 4], &g->background[pi * 4], 4); + } + } + } else { + // This is a non-disposal case eithe way, so just + // leave the pixels as is, and they will become the new background + // 1: do not dispose + // 0: not specified. 
+ } + + // background is what out is after the undoing of the previou frame; + memcpy(g->background, g->out, 4 * g->w * g->h); + } + + // clear my history; + memset(g->history, 0x00, g->w * g->h); // pixels that were affected previous frame + + for (;;) { + int tag = stbi__get8(s); + switch (tag) { + case 0x2C: /* Image Descriptor */ + { + stbi__int32 x, y, w, h; + stbi_uc * o; + + x = stbi__get16le(s); + y = stbi__get16le(s); + w = stbi__get16le(s); + h = stbi__get16le(s); + if (((x + w) > (g->w)) || ((y + h) > (g->h))) + return stbi__errpuc("bad Image Descriptor", "Corrupt GIF"); + + g->line_size = g->w * 4; + g->start_x = x * 4; + g->start_y = y * g->line_size; + g->max_x = g->start_x + w * 4; + g->max_y = g->start_y + h * g->line_size; + g->cur_x = g->start_x; + g->cur_y = g->start_y; + + // if the width of the specified rectangle is 0, that means + // we may not see *any* pixels or the image is malformed; + // to make sure this is caught, move the current y down to + // max_y (which is what out_gif_code checks). + if (w == 0) + g->cur_y = g->max_y; + + g->lflags = stbi__get8(s); + + if (g->lflags & 0x40) { + g->step = 8 * g->line_size; // first interlaced spacing + g->parse = 3; + } else { + g->step = g->line_size; + g->parse = 0; + } + + if (g->lflags & 0x80) { + stbi__gif_parse_colortable(s, g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1); + g->color_table = (stbi_uc *)g->lpal; + } else if (g->flags & 0x80) { + g->color_table = (stbi_uc *)g->pal; + } else + return stbi__errpuc("missing color table", "Corrupt GIF"); + + o = stbi__process_gif_raster(s, g); + if (!o) + return NULL; + + // if this was the first frame, + pcount = g->w * g->h; + if (first_frame && (g->bgindex > 0)) { + // if first frame, any pixel not drawn to gets the background color + for (pi = 0; pi < pcount; ++pi) { + if (g->history[pi] == 0) { + g->pal[g->bgindex][3] = + 255; // just in case it was made transparent, undo that; It will be reset next frame if need be; + memcpy(&g->out[pi * 4], &g->pal[g->bgindex], 4); + } + } + } + + return o; + } + + case 0x21: // Comment Extension. + { + int len; + int ext = stbi__get8(s); + if (ext == 0xF9) { // Graphic Control Extension. + len = stbi__get8(s); + if (len == 4) { + g->eflags = stbi__get8(s); + g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths. 
+ + // unset old transparent + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 255; + } + if (g->eflags & 0x01) { + g->transparent = stbi__get8(s); + if (g->transparent >= 0) { + g->pal[g->transparent][3] = 0; + } + } else { + // don't need transparent + stbi__skip(s, 1); + g->transparent = -1; + } + } else { + stbi__skip(s, len); + break; + } + } + while ((len = stbi__get8(s)) != 0) { + stbi__skip(s, len); + } + break; + } + + case 0x3B: // gif stream termination code + return (stbi_uc *)s; // using '1' causes warning on some compilers + + default: + return stbi__errpuc("unknown code", "Corrupt GIF"); + } + } +} + +static void * stbi__load_gif_main_outofmem(stbi__gif * g, stbi_uc * out, int ** delays) { + STBI_FREE(g->out); + STBI_FREE(g->history); + STBI_FREE(g->background); + + if (out) + STBI_FREE(out); + if (delays && *delays) + STBI_FREE(*delays); + return stbi__errpuc("outofmem", "Out of memory"); +} + +static void * stbi__load_gif_main(stbi__context * s, int ** delays, int * x, int * y, int * z, int * comp, int req_comp) { + if (stbi__gif_test(s)) { + int layers = 0; + stbi_uc * u = 0; + stbi_uc * out = 0; + stbi_uc * two_back = 0; + stbi__gif g; + int stride; + int out_size = 0; + int delays_size = 0; + + STBI_NOTUSED(out_size); + STBI_NOTUSED(delays_size); + + memset(&g, 0, sizeof(g)); + if (delays) { + *delays = 0; + } + + do { + u = stbi__gif_load_next(s, &g, comp, req_comp, two_back); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + + if (u) { + *x = g.w; + *y = g.h; + ++layers; + stride = g.w * g.h * 4; + + if (out) { + void * tmp = (stbi_uc *)STBI_REALLOC_SIZED(out, out_size, layers * stride); + if (!tmp) + return stbi__load_gif_main_outofmem(&g, out, delays); + else { + out = (stbi_uc *)tmp; + out_size = layers * stride; + } + + if (delays) { + int * new_delays = (int *)STBI_REALLOC_SIZED(*delays, delays_size, sizeof(int) * layers); + if (!new_delays) + return stbi__load_gif_main_outofmem(&g, out, delays); + *delays = new_delays; + delays_size = layers * sizeof(int); + } + } else { + out = (stbi_uc *)stbi__malloc(layers * stride); + if (!out) + return stbi__load_gif_main_outofmem(&g, out, delays); + out_size = layers * stride; + if (delays) { + *delays = (int *)stbi__malloc(layers * sizeof(int)); + if (!*delays) + return stbi__load_gif_main_outofmem(&g, out, delays); + delays_size = layers * sizeof(int); + } + } + memcpy(out + ((layers - 1) * stride), u, stride); + if (layers >= 2) { + two_back = out - 2 * stride; + } + + if (delays) { + (*delays)[layers - 1U] = g.delay; + } + } + } while (u != 0); + + // free temp buffer; + STBI_FREE(g.out); + STBI_FREE(g.history); + STBI_FREE(g.background); + + // do the final conversion after loading everything; + if (req_comp && req_comp != 4) + out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h); + + *z = layers; + return out; + } else { + return stbi__errpuc("not GIF", "Image was not as a gif type."); + } +} + +static void * stbi__gif_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { + stbi_uc * u = 0; + stbi__gif g; + memset(&g, 0, sizeof(g)); + STBI_NOTUSED(ri); + + u = stbi__gif_load_next(s, &g, comp, req_comp, 0); + if (u == (stbi_uc *)s) + u = 0; // end of animated gif marker + if (u) { + *x = g.w; + *y = g.h; + + // moved conversion to after successful load so that the same + // can be done for multiple frames. 
+ if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } + + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); + + return u; +} + +static int stbi__gif_info(stbi__context * s, int * x, int * y, int * comp) { return stbi__gif_info_raw(s, x, y, comp); } +#endif + +// ************************************************************************************************* +// Radiance RGBE HDR loader +// originally by Nicolas Schulz +#ifndef STBI_NO_HDR +static int stbi__hdr_test_core(stbi__context * s, const char * signature) { + int i; + for (i = 0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; +} + +static int stbi__hdr_test(stbi__context * s) { + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if (!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; +} + +#define STBI__HDR_BUFLEN 1024 +static char * stbi__hdr_gettoken(stbi__context * z, char * buffer) { + int len = 0; + char c = '\0'; + + c = (char)stbi__get8(z); + + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN - 1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char)stbi__get8(z); + } + + buffer[len] = 0; + return buffer; +} + +static void stbi__hdr_convert(float * output, stbi_uc * input, int req_comp) { + if (input[3] != 0) { + float f1; + // Exponent + f1 = (float)ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) + output[1] = 1; + if (req_comp == 4) + output[3] = 1; + } else { + switch (req_comp) { + case 4: + output[3] = 1; /* fallthrough */ + case 3: + output[0] = output[1] = output[2] = 0; + break; + case 2: + output[1] = 1; /* fallthrough */ + case 1: + output[0] = 0; + break; + } + } +} + +static float * stbi__hdr_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { + char buffer[STBI__HDR_BUFLEN]; + char * token; + int valid = 0; + int width, height; + stbi_uc * scanline; + float * hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1, c2, z; + const char * headerToken; + STBI_NOTUSED(ri); + + // Check identifier + headerToken = stbi__hdr_gettoken(s, buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); + + // Parse header + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if (token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) + return stbi__errpf("unsupported format", "Unsupported HDR format"); + + // Parse width and height + // can't use sscanf() if we're not using stdio! 
+ token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + height = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) + return stbi__errpf("unsupported data layout", "Unsupported HDR format"); + token += 3; + width = (int)strtol(token, NULL, 10); + + if (height > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + if (width > STBI_MAX_DIMENSIONS) + return stbi__errpf("too large", "Very large image (corrupt?)"); + + *x = width; + *y = height; + + if (comp) + *comp = 3; + if (req_comp == 0) + req_comp = 3; + + if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0)) + return stbi__errpf("too large", "HDR image is too large"); + + // Read data + hdr_data = (float *)stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0); + if (!hdr_data) + return stbi__errpf("outofmem", "Out of memory"); + + // Load image data + // image data is stored as some number of sca + if (width < 8 || width >= 32768) { + // Read flat data + for (j = 0; j < height; ++j) { + for (i = 0; i < width; ++i) { + stbi_uc rgbe[4]; + main_decode_loop: + stbi__getn(s, rgbe, 4); + stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp); + } + } + } else { + // Read RLE-encoded data + scanline = NULL; + + for (j = 0; j < height; ++j) { + c1 = stbi__get8(s); + c2 = stbi__get8(s); + len = stbi__get8(s); + if (c1 != 2 || c2 != 2 || (len & 0x80)) { + // not run-length encoded, so we have to actually use THIS data as a decoded + // pixel (note this can't be a valid pixel--one of RGB must be >= 128) + stbi_uc rgbe[4]; + rgbe[0] = (stbi_uc)c1; + rgbe[1] = (stbi_uc)c2; + rgbe[2] = (stbi_uc)len; + rgbe[3] = (stbi_uc)stbi__get8(s); + stbi__hdr_convert(hdr_data, rgbe, req_comp); + i = 1; + j = 0; + STBI_FREE(scanline); + goto main_decode_loop; // yes, this makes no sense + } + len <<= 8; + len |= stbi__get8(s); + if (len != width) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); + } + if (scanline == NULL) { + scanline = (stbi_uc *)stbi__malloc_mad2(width, 4, 0); + if (!scanline) { + STBI_FREE(hdr_data); + return stbi__errpf("outofmem", "Out of memory"); + } + } + + for (k = 0; k < 4; ++k) { + int nleft; + i = 0; + while ((nleft = width - i) > 0) { + count = stbi__get8(s); + if (count > 128) { + // Run + value = stbi__get8(s); + count -= 128; + if ((count == 0) || (count > nleft)) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); + } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = value; + } else { + // Dump + if ((count == 0) || (count > nleft)) { + STBI_FREE(hdr_data); + STBI_FREE(scanline); + return stbi__errpf("corrupt", "bad RLE data in HDR"); + } + for (z = 0; z < count; ++z) + scanline[i++ * 4 + k] = stbi__get8(s); + } + } + } + for (i = 0; i < width; ++i) + stbi__hdr_convert(hdr_data + (j * width + i) * req_comp, scanline + i * 4, req_comp); + } + if (scanline) + STBI_FREE(scanline); + } + + return hdr_data; +} + +static int stbi__hdr_info(stbi__context * s, int * x, int * y, int * comp) { + char buffer[STBI__HDR_BUFLEN]; + char * token; + int valid = 0; + int dummy; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (stbi__hdr_test(s) == 0) { + stbi__rewind(s); + return 0; + } + + for (;;) { + token = stbi__hdr_gettoken(s, buffer); + if 
(token[0] == 0) + break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) + valid = 1; + } + + if (!valid) { + stbi__rewind(s); + return 0; + } + token = stbi__hdr_gettoken(s, buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *y = (int)strtol(token, &token, 10); + while (*token == ' ') + ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind(s); + return 0; + } + token += 3; + *x = (int)strtol(token, NULL, 10); + *comp = 3; + return 1; +} +#endif // STBI_NO_HDR + +#ifndef STBI_NO_BMP +static int stbi__bmp_info(stbi__context * s, int * x, int * y, int * comp) { + void * p; + stbi__bmp_data info; + + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + if (p == NULL) { + stbi__rewind(s); + return 0; + } + if (x) + *x = s->img_x; + if (y) + *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; +} +#endif + +#ifndef STBI_NO_PSD +static int stbi__psd_info(stbi__context * s, int * x, int * y, int * comp) { + int channelCount, dummy, depth; + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind(s); + return 0; + } + *comp = 4; + return 1; +} + +static int stbi__psd_is16(stbi__context * s) { + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind(s); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind(s); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind(s); + return 0; + } + STBI_NOTUSED(stbi__get32be(s)); + STBI_NOTUSED(stbi__get32be(s)); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind(s); + return 0; + } + return 1; +} +#endif + +#ifndef STBI_NO_PIC +static int stbi__pic_info(stbi__context * s, int * x, int * y, int * comp) { + int act_comp = 0, num_packets = 0, chained, dummy; + stbi__pic_packet packets[10]; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + if (!stbi__pic_is4(s, "\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 88); + + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind(s); + return 0; + } + if ((*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 8); + + do { + stbi__pic_packet * packet; + + if (num_packets == sizeof(packets) / sizeof(packets[0])) + return 0; + + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; + + if (stbi__at_eof(s)) { + stbi__rewind(s); + return 0; + } + if (packet->size != 8) { + stbi__rewind(s); + return 0; + } + } while (chained); + + *comp = (act_comp & 0x10 ? 
4 : 3); + + return 1; +} +#endif + +// ************************************************************************************************* +// Portable Gray Map and Portable Pixel Map loader +// by Ken Miller +// +// PGM: http://netpbm.sourceforge.net/doc/pgm.html +// PPM: http://netpbm.sourceforge.net/doc/ppm.html +// +// Known limitations: +// Does not support comments in the header section +// Does not support ASCII image data (formats P2 and P3) + +#ifndef STBI_NO_PNM + +static int stbi__pnm_test(stbi__context * s) { + char p, t; + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + return 1; +} + +static void * stbi__pnm_load(stbi__context * s, int * x, int * y, int * comp, int req_comp, stbi__result_info * ri) { + stbi_uc * out; + STBI_NOTUSED(ri); + + ri->bits_per_channel = stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n); + if (ri->bits_per_channel == 0) + return 0; + + if (s->img_y > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) + return stbi__errpuc("too large", "Very large image (corrupt?)"); + + *x = s->img_x; + *y = s->img_y; + if (comp) + *comp = s->img_n; + + if (!stbi__mad4sizes_valid(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0)) + return stbi__errpuc("too large", "PNM too large"); + + out = (stbi_uc *)stbi__malloc_mad4(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0); + if (!out) + return stbi__errpuc("outofmem", "Out of memory"); + if (!stbi__getn(s, out, s->img_n * s->img_x * s->img_y * (ri->bits_per_channel / 8))) { + STBI_FREE(out); + return stbi__errpuc("bad PNM", "PNM file truncated"); + } + + if (req_comp && req_comp != s->img_n) { + if (ri->bits_per_channel == 16) { + out = (stbi_uc *)stbi__convert_format16((stbi__uint16 *)out, s->img_n, req_comp, s->img_x, s->img_y); + } else { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + } + if (out == NULL) + return out; // stbi__convert_format frees input on failure + } + return out; +} + +static int stbi__pnm_isspace(char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; } + +static void stbi__pnm_skip_whitespace(stbi__context * s, char * c) { + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char)stbi__get8(s); + + if (stbi__at_eof(s) || *c != '#') + break; + + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r') + *c = (char)stbi__get8(s); + } +} + +static int stbi__pnm_isdigit(char c) { return c >= '0' && c <= '9'; } + +static int stbi__pnm_getinteger(stbi__context * s, char * c) { + int value = 0; + + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value * 10 + (*c - '0'); + *c = (char)stbi__get8(s); + if ((value > 214748364) || (value == 214748364 && *c > '7')) + return stbi__err("integer parse overflow", "Parsing an integer in the PPM header overflowed a 32-bit int"); + } + + return value; +} + +static int stbi__pnm_info(stbi__context * s, int * x, int * y, int * comp) { + int maxv, dummy; + char c, p, t; + + if (!x) + x = &dummy; + if (!y) + y = &dummy; + if (!comp) + comp = &dummy; + + stbi__rewind(s); + + // Get identifier + p = (char)stbi__get8(s); + t = (char)stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + + *comp = (t == '6') ? 
3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm + + c = (char)stbi__get8(s); + stbi__pnm_skip_whitespace(s, &c); + + *x = stbi__pnm_getinteger(s, &c); // read width + if (*x == 0) + return stbi__err("invalid width", "PPM image header had zero or overflowing width"); + stbi__pnm_skip_whitespace(s, &c); + + *y = stbi__pnm_getinteger(s, &c); // read height + if (*y == 0) + return stbi__err("invalid width", "PPM image header had zero or overflowing width"); + stbi__pnm_skip_whitespace(s, &c); + + maxv = stbi__pnm_getinteger(s, &c); // read max value + if (maxv > 65535) + return stbi__err("max value > 65535", "PPM image supports only 8-bit and 16-bit images"); + else if (maxv > 255) + return 16; + else + return 8; +} + +static int stbi__pnm_is16(stbi__context * s) { + if (stbi__pnm_info(s, NULL, NULL, NULL) == 16) + return 1; + return 0; +} +#endif + +static int stbi__info_main(stbi__context * s, int * x, int * y, int * comp) { +#ifndef STBI_NO_JPEG + if (stbi__jpeg_info(s, x, y, comp)) + return 1; +#endif + +#ifndef STBI_NO_PNG + if (stbi__png_info(s, x, y, comp)) + return 1; +#endif + +#ifndef STBI_NO_GIF + if (stbi__gif_info(s, x, y, comp)) + return 1; +#endif + +#ifndef STBI_NO_BMP + if (stbi__bmp_info(s, x, y, comp)) + return 1; +#endif + +#ifndef STBI_NO_PSD + if (stbi__psd_info(s, x, y, comp)) + return 1; +#endif + +#ifndef STBI_NO_PIC + if (stbi__pic_info(s, x, y, comp)) + return 1; +#endif + +#ifndef STBI_NO_PNM + if (stbi__pnm_info(s, x, y, comp)) + return 1; +#endif + +#ifndef STBI_NO_HDR + if (stbi__hdr_info(s, x, y, comp)) + return 1; +#endif + +// test tga last because it's a crappy test! +#ifndef STBI_NO_TGA + if (stbi__tga_info(s, x, y, comp)) + return 1; +#endif + return stbi__err("unknown image type", "Image not of any known type, or corrupt"); +} + +static int stbi__is_16_main(stbi__context * s) { +#ifndef STBI_NO_PNG + if (stbi__png_is16(s)) + return 1; +#endif + +#ifndef STBI_NO_PSD + if (stbi__psd_is16(s)) + return 1; +#endif + +#ifndef STBI_NO_PNM + if (stbi__pnm_is16(s)) + return 1; +#endif + return 0; +} + +#ifndef STBI_NO_STDIO +STBIDEF int stbi_info(char const * filename, int * x, int * y, int * comp) { + FILE * f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_info_from_file(f, x, y, comp); + fclose(f); + return result; +} + +STBIDEF int stbi_info_from_file(FILE * f, int * x, int * y, int * comp) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__info_main(&s, x, y, comp); + fseek(f, pos, SEEK_SET); + return r; +} + +STBIDEF int stbi_is_16_bit(char const * filename) { + FILE * f = stbi__fopen(filename, "rb"); + int result; + if (!f) + return stbi__err("can't fopen", "Unable to open file"); + result = stbi_is_16_bit_from_file(f); + fclose(f); + return result; +} + +STBIDEF int stbi_is_16_bit_from_file(FILE * f) { + int r; + stbi__context s; + long pos = ftell(f); + stbi__start_file(&s, f); + r = stbi__is_16_main(&s); + fseek(f, pos, SEEK_SET); + return r; +} +#endif // !STBI_NO_STDIO + +STBIDEF int stbi_info_from_memory(stbi_uc const * buffer, int len, int * x, int * y, int * comp) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__info_main(&s, x, y, comp); +} + +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const * c, void * user, int * x, int * y, int * comp) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__info_main(&s, x, y, comp); +} + +STBIDEF int 
stbi_is_16_bit_from_memory(stbi_uc const * buffer, int len) { + stbi__context s; + stbi__start_mem(&s, buffer, len); + return stbi__is_16_main(&s); +} + +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const * c, void * user) { + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)c, user); + return stbi__is_16_main(&s); +} + +#endif // STB_IMAGE_IMPLEMENTATION + +/* + revision history: + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.19 (2018-02-11) fix warning + 2.18 (2018-01-30) fix warnings + 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug + 1-bit BMP + *_is_16_bit api + avoid warnings + 2.16 (2017-07-23) all functions have 16-bit variants; + STBI_NO_STDIO works again; + compilation fixes; + fix rounding in unpremultiply; + optimize vertical flip; + disable raw_len validation; + documentation fixes + 2.15 (2017-03-18) fix png-1,2,4 bug; now all Imagenet JPGs decode; + warning fixes; disable run-time SSE detection on gcc; + uniform handling of optional "return" values; + thread-safe initialization of zlib tables + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs + 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now + 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes + 2.11 (2016-04-02) allocate large structures on the stack + remove white matting for transparent PSD + fix reported channel count for PNG & BMP + re-enable SSE2 in non-gcc 64-bit + support RGB-formatted JPEG + read 16-bit PNGs (only as 8-bit) + 2.10 (2016-01-22) avoid warning introduced in 2.09 by STBI_REALLOC_SIZED + 2.09 (2016-01-16) allow comments in PNM files + 16-bit-per-pixel TGA (not bit-per-component) + info() for TGA could break due to .hdr handling + info() for BMP to shares code instead of sloppy parse + can use STBI_REALLOC_SIZED if allocator doesn't support realloc + code cleanup + 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD as RGBA + 2.07 (2015-09-13) fix compiler warnings + partial animated GIF support + limited 16-bpc PSD support + #ifdef unused functions + bug with < 92 byte PIC,PNM,HDR,TGA + 2.06 (2015-04-19) fix bug where PSD returns wrong '*comp' value + 2.05 (2015-04-19) fix bug in progressive JPEG handling, fix warning + 2.04 (2015-04-15) try to re-enable SIMD on MinGW 64-bit + 2.03 (2015-04-12) extra corruption checking (mmozeiko) + stbi_set_flip_vertically_on_load (nguillemot) + fix NEON support; fix mingw support + 2.02 (2015-01-19) fix incorrect assert, fix warning + 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit without -msse2 + 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG + 2.00 (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) + progressive JPEG (stb) + PGM/PPM support (Ken Miller) + STBI_MALLOC,STBI_REALLOC,STBI_FREE + GIF bugfix -- seemingly never worked + STBI_NO_*, STBI_ONLY_* + 1.48 (2014-12-14) fix incorrectly-named assert() + 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar Cornut & stb) + optimize PNG (ryg) + fix bug in interlaced PNG with user-specified channel count (stb) + 1.46 (2014-08-26) + fix broken tRNS chunk (colorkey-style transparency) in non-paletted PNG + 1.45 (2014-08-16) + fix MSVC-ARM internal compiler error by wrapping malloc + 1.44 (2014-08-07) + various warning fixes from Ronny Chevalier + 1.43 (2014-07-15) + fix MSVC-only compiler problem in code changed in 1.42 + 1.42 (2014-07-09) + don't define _CRT_SECURE_NO_WARNINGS (affects user code) + fixes to 
stbi__cleanup_jpeg path + added STBI_ASSERT to avoid requiring assert.h + 1.41 (2014-06-25) + fix search&replace from 1.36 that messed up comments/error messages + 1.40 (2014-06-22) + fix gcc struct-initialization warning + 1.39 (2014-06-15) + fix to TGA optimization when req_comp != number of components in TGA; + fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my test suite) + add support for BMP version 5 (more ignored fields) + 1.38 (2014-06-06) + suppress MSVC warnings on integer casts truncating values + fix accidental rename of 'skip' field of I/O + 1.37 (2014-06-04) + remove duplicate typedef + 1.36 (2014-06-03) + convert to header file single-file library + if de-iphone isn't set, load iphone images color-swapped instead of returning NULL + 1.35 (2014-05-27) + various warnings + fix broken STBI_SIMD path + fix bug where stbi_load_from_file no longer left file pointer in correct place + fix broken non-easy path for 32-bit BMP (possibly never used) + TGA optimization by Arseny Kapoulkine + 1.34 (unknown) + use STBI_NOTUSED in stbi__resample_row_generic(), fix one more leak in tga failure case + 1.33 (2011-07-14) + make stbi_is_hdr work in STBI_NO_HDR (as specified), minor compiler-friendly improvements + 1.32 (2011-07-13) + support for "info" function for all supported filetypes (SpartanJ) + 1.31 (2011-06-20) + a few more leak fixes, bug in PNG handling (SpartanJ) + 1.30 (2011-06-11) + added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) + removed deprecated format-specific test/load functions + removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks + anyway error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) fix inefficiency in + decoding 32-bit BMP (David Woo) 1.29 (2010-08-16) various warning fixes from Aurelien Pocheville 1.28 (2010-08-01) + fix bug in GIF palette transparency (SpartanJ) + 1.27 (2010-08-01) + cast-to-stbi_uc to fix warnings + 1.26 (2010-07-24) + fix bug in file buffering for PNG reported by SpartanJ + 1.25 (2010-07-17) + refix trans_data warning (Won Chun) + 1.24 (2010-07-12) + perf improvements reading from files on platforms with lock-heavy fgetc() + minor perf improvements for jpeg + deprecated type-specific functions so we'll get feedback if they're needed + attempt to fix trans_data warning (Won Chun) + 1.23 fixed bug in iPhone support + 1.22 (2010-07-10) + removed image *writing* support + stbi_info support from Jetro Lauha + GIF support from Jean-Marc Lienher + iPhone PNG-extensions from James Brown + warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. Janez (U+017D)emva) + 1.21 fix use of 'stbi_uc' in header (reported by jon blow) + 1.20 added support for Softimage PIC, by Tom Seddon + 1.19 bug in interlaced PNG corruption check (found by ryg) + 1.18 (2008-08-02) + fix a threading bug (local mutable static) + 1.17 support interlaced PNG + 1.16 major bugfix - stbi__convert_format converted one too many pixels + 1.15 initialize some fields for thread safety + 1.14 fix threadsafe conversion bug + header-file-only version (#define STBI_HEADER_FILE_ONLY before including) + 1.13 threadsafe + 1.12 const qualifiers in the API + 1.11 Support installable IDCT, colorspace conversion routines + 1.10 Fixes for 64-bit (don't use "unsigned long") + optimized upsampling by Fabian "ryg" Giesen + 1.09 Fix format-conversion for PSD code (bad global variables!) 
+ 1.08 Thatcher Ulrich's PSD code integrated by Nicolas Schulz + 1.07 attempt to fix C++ warning/errors again + 1.06 attempt to fix C++ warning/errors again + 1.05 fix TGA loading to return correct *comp and use good luminance calc + 1.04 default float alpha is 1, not 255; use 'void *' for stbi_image_free + 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR + 1.02 support for (subset of) HDR files, float interface for preferred access to them + 1.01 fix bug: possible bug in handling right-side up bmps... not sure + fix bug: the stbi__bmp_load() and stbi__tga_load() functions didn't work at all + 1.00 interface to zlib that skips zlib header + 0.99 correct handling of alpha in palette + 0.98 TGA loader by lonesock; dynamically add loaders (untested) + 0.97 jpeg errors on too large a file; also catch another malloc failure + 0.96 fix detection of invalid v value - particleman@mollyrocket forum + 0.95 during header scan, seek to markers in case of padding + 0.94 STBI_NO_STDIO to disable stdio usage; rename all #defines the same + 0.93 handle jpegtran output; verbose errors + 0.92 read 4,8,16,24,32-bit BMP files of several formats + 0.91 output 24-bit Windows 3.0 BMP files + 0.90 fix a few more warnings; bump version number to approach 1.0 + 0.61 bugfixes due to Marc LeBlanc, Christopher Lloyd + 0.60 fix compiling as c++ + 0.59 fix warnings: merge Dave Moore's -Wall fixes + 0.58 fix bug: zlib uncompressed mode len/nlen was wrong endian + 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but less than 16 available + 0.56 fix bug: zlib uncompressed mode len vs. nlen + 0.55 fix bug: restart_interval not initialized to 0 + 0.54 allow NULL for 'int *comp' + 0.53 fix bug in png 3->4; speedup png decoding + 0.52 png handles req_comp=3,4 directly; minor cleanup; jpeg comments + 0.51 obey req_comp requests, 1-component jpegs return as 1-component, + on 'test' only check type, not whether we support this variant + 0.50 (2006-11-19) + first released version +*/ + +/* +------------------------------------------------------------------------------ +This software is available under 2 licenses -- choose whichever you prefer. +------------------------------------------------------------------------------ +ALTERNATIVE A - MIT License +Copyright (c) 2017 Sean Barrett +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------------------------------------------------------------ +ALTERNATIVE B - Public Domain (www.unlicense.org) +This is free and unencumbered software released into the public domain. 
+Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
+software, either in source code form or as a compiled binary, for any purpose,
+commercial or non-commercial, and by any means.
+In jurisdictions that recognize copyright laws, the author or authors of this
+software dedicate any and all copyright interest in the software to the public
+domain. We make this dedication for the benefit of the public at large and to
+the detriment of our heirs and successors. We intend this dedication to be an
+overt act of relinquishment in perpetuity of all present and future rights to
+this software under copyright law.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+------------------------------------------------------------------------------
+*/
diff --git a/applications/src/llama/common/train.cpp b/applications/src/llama/common/train.cpp
new file mode 100644
index 000000000..dcf9614e4
--- /dev/null
+++ b/applications/src/llama/common/train.cpp
@@ -0,0 +1,1513 @@
+#include "train.h"
+#include "common.h"
+
+#include <random>
+#include <sstream>
+#include <functional>
+
+struct random_normal_distribution {
+    std::mt19937 gen;
+    std::normal_distribution<float> rd;
+    float min;
+    float max;
+};
+
+struct random_uniform_distribution {
+    std::mt19937 gen;
+    std::uniform_real_distribution<float> rd;
+};
+
+struct train_state * init_train_state() {
+    struct train_state * state = new struct train_state;
+    state->train_its = 0;
+    state->train_samples = 0;
+    state->train_tokens = 0;
+    state->train_epochs = 0;
+    state->shuffle_samples_hash = 0;
+    state->shuffle_sample_count = 0;
+    state->shuffle_next_sample = 0;
+    state->shuffle_rng_state_current = "";
+    state->shuffle_rng_state_next = "";
+
+    state->opt = new struct ggml_opt_context;
+    state->opt->ctx = NULL;
+    state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
+    state->opt->loss_after = 0.0f;
+
+    return state;
+}
+
+void free_train_state(struct train_state * state) {
+    delete state->opt;
+    delete state;
+}
+
+struct random_normal_distribution * init_random_normal_distribution(
+    int seed, float mean, float std, float min, float max
+) {
+    struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution));
+    rnd->gen = std::mt19937(seed);
+    rnd->rd = std::normal_distribution<float>{mean, std};
+    rnd->min = min;
+    rnd->max = max;
+    return rnd;
+}
+
+struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) {
+    struct random_uniform_distribution * rnd = (struct random_uniform_distribution *) malloc(sizeof(struct random_uniform_distribution));
+    rnd->gen = std::mt19937(seed);
+    rnd->rd = std::uniform_real_distribution<float>{min, max};
+    return rnd;
+}
+
+void free_random_normal_distribution (struct random_normal_distribution * rnd) {
+    free(rnd);
+}
+
+void free_random_uniform_distribution(struct random_uniform_distribution * rnd) {
+    free(rnd);
+}
+
+struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
+    float scale = 1.0f; // xavier
+    switch (ggml_n_dims(tensor)) {
case 1: + scale /= sqrtf((float) tensor->ne[0]); + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]); + *dst = scale * frand_normal(rnd); + } + break; + case 2: + scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]); + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + *dst = scale * frand_normal(rnd); + } + } + break; + case 3: + scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]); + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]); + *dst = scale * frand_normal(rnd); + } + } + } + break; + case 4: + scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]); + for (int i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]); + *dst = scale * frand_normal(rnd); + } + } + } + } + break; + default: + die("Unsupported tensor->n_dims"); + }; + return tensor; +} + +struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) { + switch (ggml_n_dims(tensor)) { + case 1: + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]); + *dst = frand_uniform(rnd); + } + break; + case 2: + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + *dst = frand_uniform(rnd); + } + } + break; + case 3: + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]); + *dst = frand_uniform(rnd); + } + } + } + break; + case 4: + for (int i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]); + *dst = frand_uniform(rnd); + } + } + } + } + break; + default: + die("Unsupported tensor->n_dims"); + }; + return tensor; +} + +float frand() { + return (float)rand()/((float)(RAND_MAX) + 1.0f); +} + +float frand_normal(struct random_normal_distribution * rnd) { + return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max); +} + +float frand_uniform(struct random_uniform_distribution * rnd) { + return rnd->rd(rnd->gen); +} + +int clamp(const int v, const int min, const int max) { + return ((v < min) ? (min) : (v > max) ? (max) : v); +} + +float fclamp(const float v, const float min, const float max) { + return ((v < min) ? (min) : (v > max) ? 
(max) : v); +} + +void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) { + GGML_ASSERT(tensor->ne[0] == ne0); + GGML_ASSERT(tensor->ne[1] == 1); + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); +} + +void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) { + GGML_ASSERT(tensor->ne[0] == ne0); + GGML_ASSERT(tensor->ne[1] == ne1); + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); +} + +void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) { + GGML_ASSERT(tensor->ne[0] == ne0); + GGML_ASSERT(tensor->ne[1] == ne1); + GGML_ASSERT(tensor->ne[2] == ne2); + GGML_ASSERT(tensor->ne[3] == 1); +} + +void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { + GGML_ASSERT(tensor->ne[0] == ne0); + GGML_ASSERT(tensor->ne[1] == ne1); + GGML_ASSERT(tensor->ne[2] == ne2); + GGML_ASSERT(tensor->ne[3] == ne3); +} + +int64_t get_example_targets_batch( + struct llama_context * lctx, + struct ggml_tensor * tokens_input, + struct ggml_tensor * target_probs, + int64_t example_id, + const size_t * samples_offs, + const size_t * samples_begin, + const size_t * samples_size, + size_t samples_count, + const llama_token * train_data, + size_t n_train_data, + bool separate_with_eos, + bool separate_with_bos, + bool fill_with_next_samples, + bool sample_random_offsets +) { + GGML_ASSERT(samples_count > 0); + GGML_ASSERT(ggml_is_matrix(tokens_input)); + GGML_ASSERT(ggml_is_3d(target_probs)); + int64_t n_vocab = target_probs->ne[0]; + int64_t n_tokens = tokens_input->ne[0]; + int64_t n_batch = tokens_input->ne[1]; + GGML_ASSERT(n_vocab == target_probs->ne[0]); + GGML_ASSERT(n_tokens == target_probs->ne[1]); + GGML_ASSERT(n_batch == target_probs->ne[2]); + + int64_t used_samples = 0; + + ggml_set_f32(target_probs, 0.0f); + llama_token bos = llama_token_bos(llama_get_model(lctx)); + llama_token eos = llama_token_eos(llama_get_model(lctx)); + // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples); + for (int k=0; k= sample_size && fill_with_next_samples) { + if (!sample_separation_eos) { + // insert eos token to separate samples + sample_separation_eos = true; + } else if (!sample_separation_bos) { + // insert bos token to separate samples + sample_separation_bos = true; + token = bos; + } else { + // sample separation is done, continue with next sample + sample_separation_eos = !separate_with_eos; + sample_separation_bos = !separate_with_bos; + sample_offs = 0; + sample_idx = (example_id + used_samples) % samples_count; + sample_begin = samples_begin[sample_idx]; + sample_size = samples_size[sample_idx]; + ++used_samples; + } + } + // note: no else-if here + if (sample_offs < sample_size) { + token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1)); + ++sample_offs; + } + ggml_set_f32_nd(target_probs, token, (int) i, (int) k, 0, +1.0f); + if (i+1> rng; +} + +std::string mt19937_get_state(const std::mt19937& rng) { + std::stringstream s_rng_state; + s_rng_state.imbue(std::locale::classic()); + s_rng_state << rng; + return s_rng_state.str(); +} + +std::string mt19937_seed_to_state(unsigned seed) { + std::mt19937 rng(seed); + return mt19937_get_state(rng); +} + +std::string shuffle_samples( + const std::string & rng_state, + size_t * shuffled_offs, + size_t * shuffled_begins, + size_t * shuffled_sizes, + const size_t * begins, + const size_t * sizes, + size_t count) { + if (count == 0) return 
rng_state; + + std::mt19937 rng; + mt19937_set_state(rng, rng_state); + + // sort indices by random value for each index + std::vector idcs; + { + std::vector rnd; + idcs.resize(count); + rnd.resize(count); + for (unsigned i=0; i h_string; + std::hash h_ull; + size_t h = h_string(std::string(fn)); + h = hash_combine(h, h_ull((unsigned long long) sample_count)); + for (size_t i=0; i< sample_count; ++i) { + h = hash_combine(h, h_ull((unsigned long long) samples_begin[i])); + h = hash_combine(h, h_ull((unsigned long long) samples_size[i])); + } + return h; +} + +std::string replace_str(const char * s, const char * needle, const char * replacement) { + std::string str = s; + size_t pos = str.find(needle); + if (pos != std::string::npos) { + str.replace(pos, strlen(needle), replacement); + } + return str; +} + +void print_duration(double fmillis) { + if (fmillis < 1000.0f) { + printf("%.1fms", (float) fmillis); + return; + } + const int64_t one_sec = 1000; + const int64_t one_min = one_sec * 60; + const int64_t one_hour = one_min * 60; + const int64_t one_day = one_hour * 24; + + int64_t millis = (int64_t) fmillis; + int64_t days = millis/one_day; + int64_t hours = (millis - days*one_day)/one_hour; + int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min; + int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec; + + // to print int64_t either cast to (long long int) or use macro PRId64 from + if (days > 0) { + printf("%lldd ", (long long int) days); + } + printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds); +} + +float cosine_decay(int64_t step, int64_t decay_steps, float minimum) { + if (step > decay_steps) { + step = decay_steps; + } + const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps)); + const float decay = (1 - minimum)*cosine_decay + minimum; + return decay; +} + +float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) { + while (step > decay_steps) { + step -= decay_steps; + decay_steps = (int64_t) (restart_step_mult * decay_steps); + } + return cosine_decay(step, decay_steps, minimum); +} + +float learning_schedule( + int64_t step, + int64_t warmup_steps, + int64_t cos_decay_steps, + float learning_rate, + float overall_minimum, + float cos_decay_minimum, + float cos_decay_restart_step_mult, + bool enable_restart) { + + float result = + (step < warmup_steps) + ? (float) step / (float) warmup_steps + : enable_restart + ? 
cosine_decay_restart( + step - warmup_steps, + cos_decay_steps, + cos_decay_minimum, + cos_decay_restart_step_mult) + : cosine_decay( + step, + cos_decay_steps, + cos_decay_minimum); + + float min = overall_minimum / learning_rate; + result = min + result * (1.0f - min); + return result; +} + +static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) { + GGML_ASSERT(a != NULL); + GGML_ASSERT(b != NULL); + GGML_ASSERT(a->type == b->type); + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b)); + + return true; +} + +void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) { + if (dst == NULL) { + return; + } + struct ggml_tensor * t = ggml_get_tensor(ctx, name); + GGML_ASSERT(are_same_layout(dst, t)); + memcpy(dst->data, t->data, ggml_nbytes(t)); + + if (strlen(ggml_get_name(dst)) == 0) { + ggml_set_name(dst, name); + } +} + +// gguf constants +static const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type"; +static const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam"; +static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs"; +static const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version"; +static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count"; +static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count"; +static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count"; +static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized"; +static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss"; +static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss"; +static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count"; +static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count"; +static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss"; +static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step"; +static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j"; +static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k"; +static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end"; +static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count"; + +static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments"; +static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments"; +static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values"; + +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters"; +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters"; +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients"; +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients"; +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction"; +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values"; +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = 
"optimizer.lbfgs.memory_alpha"; +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys"; +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s"; +static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y"; + +static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version"; +static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"; +static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"; +static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"; +static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count"; +static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash"; +static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state"; +static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count"; +static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample"; + +#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ +{ \ + const std::string skey(key); \ + const int kid = gguf_find_key(ctx, skey.c_str()); \ + if (kid >= 0) { \ + enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ + if (ktype != (type)) { \ + die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \ + } \ + (dst) = func(ctx, kid); \ + } else if (req) { \ + die_fmt("key not found in model: %s", skey.c_str()); \ + } \ +} + +void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) { + // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read + + uint32_t file_version; + GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION); + GGML_ASSERT(file_version == 0); + + GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT); + GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT); + GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED); + + uint64_t nx; + GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT); + opt->nx = (size_t) nx; + + // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know + + std::string opt_type; + GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE); + if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) { + opt->params.type = GGML_OPT_ADAM; + + GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS); + GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS); + GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT); + + ggml_opt_init(opt->ctx, opt, opt->params, opt->nx); + + copy_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS); + copy_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS); + copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES); + } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) { + opt->params.type = GGML_OPT_LBFGS; + + GGUF_GET_KEY(fctx, 
opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT); + GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS); + GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP); + GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J); + GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K); + GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END); + GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT); + + ggml_opt_init(opt->ctx, opt, opt->params, opt->nx); + + copy_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS); + copy_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS); + copy_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS); + copy_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS); + copy_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION); + copy_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES); + copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA); + copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS); + copy_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S); + copy_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y); + } else { + die("unknown optimizer type\n"); + } +} + +void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) { + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past); + gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter); + gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized); + + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM); + gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best); + gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement); + + ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS); + ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS); + if (opt->adam.pf) { + ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES); + } + + gguf_add_tensor(fctx, opt->adam.m); + gguf_add_tensor(fctx, opt->adam.v); + if (opt->adam.pf) { + gguf_add_tensor(fctx, opt->adam.pf); + } + } break; + case GGML_OPT_LBFGS: + { + gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m); + gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best); + gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step); + gguf_set_val_i32(fctx, 
LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j); + gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k); + gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement); + + ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS); + ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS); + ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS); + ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS); + ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION); + if (opt->lbfgs.pf) { + ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES); + } + ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA); + ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS); + ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S); + ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y); + + gguf_add_tensor(fctx, opt->lbfgs.x); + gguf_add_tensor(fctx, opt->lbfgs.xp); + gguf_add_tensor(fctx, opt->lbfgs.g); + gguf_add_tensor(fctx, opt->lbfgs.gp); + gguf_add_tensor(fctx, opt->lbfgs.d); + if (opt->lbfgs.pf) { + gguf_add_tensor(fctx, opt->lbfgs.pf); + } + gguf_add_tensor(fctx, opt->lbfgs.lmal); + gguf_add_tensor(fctx, opt->lbfgs.lmys); + gguf_add_tensor(fctx, opt->lbfgs.lms); + gguf_add_tensor(fctx, opt->lbfgs.lmy); + } break; + } +} + +bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) { + if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) { + return false; + } + + uint32_t file_version; + GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); + GGML_ASSERT(file_version <= 1); + + if (file_version == 0) { + + GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); + GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); + + } else if (file_version == 1) { + + GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT); + GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT); + GGUF_GET_KEY(fctx, train->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT); + + GGUF_GET_KEY(fctx, train->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH); + GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE); + GGUF_GET_KEY(fctx, train->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, train->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE); + } + + load_opt_context_gguf(fctx, f_ggml_ctx, train->opt); + return true; +} + +void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) { + gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1); + 
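+    // file version 1 widens the iteration/sample/token counters to 64 bit and
+    // adds the epoch count and shuffle state, mirroring the version-1 branch
+    // read back by load_train_state_gguf above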
gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, train->train_epochs); + + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash); + gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, train->shuffle_rng_state_current.c_str()); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) train->shuffle_next_sample); + + save_opt_context_gguf(fctx, train->opt); +} + + +struct llama_file { + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + size_t size; + + llama_file(const char * fname, const char * mode) { + fp = std::fopen(fname, mode); + if (fp == NULL) { + size = 0; + } else { + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + } + + size_t tell() const { +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + GGML_ASSERT(ret != -1); // this really shouldn't fail + return (size_t) ret; + } + + void seek(size_t offset, int whence) { +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + GGML_ASSERT(ret == 0); // same + } + + void read_raw(void * ptr, size_t size) { + if (size == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, size, 1, fp); + if (ferror(fp)) { + die_fmt("read error: %s", strerror(errno)); + } + if (ret != 1) { + die("unexpectedly reached end of file"); + } + } + + std::uint32_t read_u32() { + std::uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + std::string read_string(std::uint32_t len) { + std::vector chars(len); + read_raw(chars.data(), len); + return std::string(chars.data(), len); + } + + void write_raw(const void * ptr, size_t size) { + if (size == 0) { + return; + } + errno = 0; + size_t ret = std::fwrite(ptr, size, 1, fp); + if (ret != 1) { + die_fmt("write error: %s", strerror(errno)); + } + } + + void write_u32(std::uint32_t val) { + write_raw(&val, sizeof(val)); + } + + ~llama_file() { + if (fp) { + std::fclose(fp); + } + } +}; + +static size_t utf8_len(char src) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast(src) >> 4; + return lookup[highbits]; +} + +// mark each byte with its utf8 unit number. +// returns the number of utf8 characters. +// e.g. when bytes == '\x61\xD0\xB0\x62', +// then utf8_units will become [0,0,1,0] +// utf8_nunits will become [1,2,2,1] and 3 is returned. +// bytes where utf8_units is zero, are the begin of an utf8 character. +static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) { + size_t offs = 0; + size_t count_utf8 = 0; + while(offs < count) { + int len = (int) utf8_len(bytes[offs]); + for (int i=0; i & out_tokens, + std::vector & out_samples_begin, + std::vector & out_samples_size) { + struct llama_file f(filename, "rb"); + + if (f.size == 0) { + out_tokens.clear(); + out_samples_begin.clear(); + out_samples_size.clear(); + printf("%s: warning: empty or not existing training data file '%s'\n", + __func__, filename); + return out_tokens.size(); + } + + // account for possible leading whitespace that will be added by tokenizer + // e.g. 
'\t' will be tokenized by llama spm tokenizer to [29871, 12] + const int n_max_tokens_overhead = 1; + + std::vector buf; + buf.resize(f.size); + + f.read_raw(buf.data(), f.size); + + std::vector utf8_units; + std::vector utf8_nunits; + utf8_units.resize(buf.size()); + utf8_nunits.resize(buf.size()); + mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size()); + + if (sample_start.size() == 0) { + // tokenize all data at once + out_tokens.resize(buf.size() + n_max_tokens_overhead); + + int n_tokens = llama_tokenize( + llama_get_model(lctx), + buf.data(), + (int) buf.size(), + out_tokens.data(), + (int) out_tokens.size(), + false, false); + if (n_tokens < 0) { + out_tokens.resize(-n_tokens); + n_tokens = llama_tokenize( + llama_get_model(lctx), + buf.data(), + (int) buf.size(), + out_tokens.data(), + (int) out_tokens.size(), + false, false); + } + if (n_tokens >= 0) { + out_tokens.resize(n_tokens); + } + + // generate sample starts at all token positions + out_samples_begin.clear(); + out_samples_begin.push_back(0); + out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size())); + size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0; + for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) { + out_samples_begin.push_back(sample_begin); + out_samples_size.push_back(context_length); + } + } else { + // split data into samples and tokenize each sample + std::string data_str(buf.data(), buf.size()); + out_samples_begin.clear(); + out_samples_size.clear(); + out_tokens.clear(); + + // find all positions of pattern sample_start + size_t sample_begin = data_str.find(sample_start, 0); + while (sample_begin != std::string::npos) { + out_samples_begin.push_back(sample_begin); + const size_t search_start = sample_begin + sample_start.size(); + sample_begin = data_str.find(sample_start, search_start); + } + if (out_samples_begin.size() == 0) { + printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n", + __func__, sample_start.c_str()); + out_samples_begin.push_back(0); + } + + out_samples_size.resize(out_samples_begin.size(), 0); + + std::vector buf_sample; + std::vector tok_sample; + + const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size()); + size_t found_too_big_sample = 0; + size_t found_too_small_sample = 0; + size_t found_empty_sample = 0; + size_t found_min_sample_size = SIZE_MAX; + size_t found_max_sample_size = 0; + + size_t max_token_text_size = 0; + int n_vocab = llama_n_vocab(llama_get_model(lctx)); + for (llama_token token=0; token < n_vocab; ++token) { + max_token_text_size = std::max( + max_token_text_size, + strlen(llama_token_get_text(llama_get_model(lctx), token))); + } + + // upper bound of context byte length. + // strings with this byte length should always tokenize to at least context_length tokens. + size_t context_byte_len = max_token_text_size*context_length; + + for (unsigned i=0; i 0) { + // sample end is in the middle of an utf8 character. + // advance sample_end to the begin of the next utf8 character. + sample_end += utf8_nunits[sample_end] - utf8_units[sample_end]; + } + size_t sample_size = sample_end - sample_begin; + if (sample_size == 0) { + ++found_empty_sample; + } + + if (sample_size > 0) { + // llama_tokenize expects zero terminated string, + // copy sample into buffer and zero terminate it. 
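+                // note: llama_tokenize returns a negative value when the output
+                // buffer is too small; its magnitude is the required token count,
+                // so the code below resizes and tokenizes a second time (the same
+                // two-pass pattern used for the whole-file branch above)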
+ buf_sample.resize(sample_size); + memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size); + + // printf("sample: '%s'\n", buf_sample.data()); + + // tokenize the sample + tok_sample.resize(buf_sample.size() + n_max_tokens_overhead); + int n_tokens = llama_tokenize(llama_get_model(lctx), + buf_sample.data(), + (int) buf_sample.size(), + tok_sample.data(), + (int) tok_sample.size(), + false, false); + if (n_tokens < 0) { + tok_sample.resize(-n_tokens); + n_tokens = llama_tokenize(llama_get_model(lctx), + buf_sample.data(), + (int) buf_sample.size(), + tok_sample.data(), + (int) tok_sample.size(), + false, false); + GGML_ASSERT(n_tokens >= 0); + } + GGML_ASSERT(n_tokens <= (int) tok_sample.size()); + + if ((size_t) n_tokens > context_length) { + ++found_too_big_sample; + } else if ((size_t) n_tokens < context_length) { + ++found_too_small_sample; + } + found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens); + found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens); + + // write out tokens, start and size of sample + // overwrite the string start position with the token start position + out_samples_begin[i] = out_tokens.size(); + out_samples_size[i] = (size_t) n_tokens; + out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens); + } else { + out_samples_begin[i] = out_tokens.size(); + out_samples_size[i] = 0; + } + + } + if (found_too_big_sample > 0) { + printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n", + __func__, found_too_big_sample, found_max_sample_size, context_length); + } + + if (found_too_small_sample > 0) { + printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n", + __func__, found_too_small_sample, found_min_sample_size, context_length); + } + + if (found_empty_sample) { + printf("%s: warning: found %zu empty samples.\n", + __func__, found_empty_sample); + } + } + printf("%s: total number of samples: %zu\n", + __func__, out_samples_begin.size()); + + GGML_ASSERT(out_samples_begin.size() == out_samples_size.size()); + + return out_tokens.size(); +} + +std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) { + std::string sit = (iteration >= 0) ? 
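+    // replace the iteration pattern in the filename, or use the 'latest' marker
+    // for a negative iteration, e.g. with the default settings
+    // "checkpoint-ITERATION.gguf" becomes "checkpoint-42.gguf" or "checkpoint-LATEST.gguf"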
std::to_string(iteration) : std::string(latest); + return replace_str(filename, pattern_it, sit.c_str()); +} + +struct train_params_common get_default_train_params_common() { + struct train_params_common params; + params.fn_train_data = "shakespeare.txt"; + params.fn_checkpoint_in = "checkpoint.gguf"; + params.fn_checkpoint_out = "checkpoint-ITERATION.gguf"; + params.pattern_fn_it = "ITERATION"; + params.fn_latest = "LATEST"; + + params.print_usage = false; + + params.save_every = 10; + + params.seed = -1; + + params.n_ctx = 128; + params.n_threads = 6; + params.n_batch = 8; + params.n_gradient_accumulation = 1; + params.n_epochs = -1; + params.n_gpu_layers = 0; + + params.custom_n_ctx = false; + + params.use_flash = true; + params.use_checkpointing = true; + + params.sample_start = ""; + params.include_sample_start = false; + params.escape = false; + params.overlapping_samples = false; + params.fill_with_next_samples = false; + params.separate_with_eos = false; + params.separate_with_bos = true; + params.sample_random_offsets = false; + params.force_reshuffle = false; + + params.opt_past = 0; + params.opt_delta = 1e-5f; + params.opt_max_no_improvement = 0; + + params.warmup = 100; + params.cos_decay_steps = 1000; + params.cos_decay_restart = 1.1f; + params.cos_decay_min = 0.1f; + params.enable_restart = false; + + params.adam_n_iter = 256; + params.adam_alpha = 1e-3f; + params.adam_min_alpha = 0; + params.adam_decay = 1e-1f; + params.adam_decay_min_ndim = 2; + params.adam_beta1 = 0.9f; + params.adam_beta2 = 0.999f; + params.adam_gclip = 1.0f; + params.adam_eps_f = 0.0f; + + return params; +} + +void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) { + // fprintf(stderr, "usage: %s [options]\n", argv[0]); + // fprintf(stderr, "\n"); + // fprintf(stderr, "options:\n"); + // fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data); + fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in); + fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out); + fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it); + fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest); + fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every); + fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n"); + fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx); + fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads); + fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch); + fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation); + fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. 
(default '%s')\n", params->sample_start.c_str()); + fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n"); + fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); + fprintf(stderr, " --overlapping-samples Samples my overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n"); + fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n"); + fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : ""); + fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : ""); + fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : ""); + fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : ""); + fprintf(stderr, " --sample-random-offsets Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : ""); + fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n"); + fprintf(stderr, " --no-flash Don't use flash attention \n"); + fprintf(stderr, " --use-flash Use flash attention (default)\n"); + fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n"); + fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n"); + fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup); + fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps); + fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart); + fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min); + fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : ""); + fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : ""); + fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past); + fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta); + fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement); + fprintf(stderr, " --epochs N Maximum number epochs to process. 
(default %d)\n", params->n_epochs); + fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter); + fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha); + fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha); + fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay); + fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim); + fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1); + fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2); + fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip); + fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f); + fprintf(stderr, " -ngl N, --n-gpu-layers N Number of model layers to offload to GPU (default %d)", params->n_gpu_layers); + fprintf(stderr, "\n"); +} + +bool consume_common_train_arg( + int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param +) { + int& i = *idx; + std::string arg = argv[i]; + const std::string arg_prefix = "--"; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + if (arg == "--train-data") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->fn_train_data = argv[i]; + } else if (arg == "--checkpoint-in") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->fn_checkpoint_in = argv[i]; + } else if (arg == "--checkpoint-out") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->fn_checkpoint_out = argv[i]; + } else if (arg == "--pattern-fn-it") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->pattern_fn_it = argv[i]; + } else if (arg == "--fn-latest") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->fn_latest = argv[i]; + } else if (arg == "--save-every") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->save_every = std::stoi(argv[i]); + } else if (arg == "-s" || arg == "--seed") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->seed = std::stoi(argv[i]); + } else if (arg == "-c" || arg == "--ctx") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->n_ctx = std::stoi(argv[i]); + params->custom_n_ctx = true; + } else if (arg == "-t" || arg == "--threads") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->n_threads = std::stoi(argv[i]); + } else if (arg == "-b" || arg == "--batch") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->n_batch = std::stoi(argv[i]); + } else if (arg == "--grad-acc") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->n_gradient_accumulation = std::max(1, std::stoi(argv[i])); + } else if (arg == "--sample-start") { + if (++i >= argc) { + *invalid_param = true; + 
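+            // value for --sample-start is missing: flag *invalid_param for the
+            // caller and return true, since the option itself was recognized
+            // (only unknown options fall through to the final 'return false')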
return true; + } + params->sample_start = std::string(argv[i]); + } else if (arg == "--escape") { + params->escape = true; + } else if (arg == "--include-sample-start") { + params->include_sample_start = true; + } else if (arg == "--overlapping-samples") { + params->overlapping_samples = true; + } else if (arg == "--fill-with-next-samples") { + params->fill_with_next_samples = true; + } else if (arg == "--separate-with-eos") { + params->separate_with_eos = true; + } else if (arg == "--separate-with-bos") { + params->separate_with_bos = true; + } else if (arg == "--no-separate-with-eos") { + params->separate_with_eos = false; + } else if (arg == "--no-separate-with-bos") { + params->separate_with_bos = false; + } else if (arg == "--sample-random-offsets") { + params->sample_random_offsets = true; + } else if (arg == "--force-reshuffle") { + params->force_reshuffle = true; + } else if (arg == "--no-flash") { + params->use_flash = false; + } else if (arg == "--use-flash") { + params->use_flash = true; + } else if (arg == "--no-checkpointing") { + params->use_checkpointing = false; + } else if (arg == "--use-checkpointing") { + params->use_checkpointing = true; + } else if (arg == "--warmup") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->warmup = std::stoi(argv[i]); + } else if (arg == "--cos-decay-steps") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->cos_decay_steps = std::stoi(argv[i]); + } else if (arg == "--cos-decay-restart") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->cos_decay_restart = std::stof(argv[i]); + } else if (arg == "--cos-decay-min") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->cos_decay_min = std::stof(argv[i]); + } else if (arg == "--enable-restart") { + params->enable_restart = true; + } else if (arg == "--disable-restart") { + params->enable_restart = false; + } else if (arg == "--opt-past") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->opt_past = std::stoi(argv[i]); + } else if (arg == "--opt-delta") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->opt_delta = std::stof(argv[i]); + } else if (arg == "--opt-max-no-improvement") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->opt_max_no_improvement = std::stoi(argv[i]); + } else if (arg == "--adam-epsf") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_eps_f = std::stof(argv[i]); + } else if (arg == "--epochs") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->n_epochs = std::stoi(argv[i]); + } else if (arg == "--adam-iter") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_n_iter = std::stoi(argv[i]); + } else if (arg == "--adam-alpha") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_alpha = std::stof(argv[i]); + } else if (arg == "--adam-min-alpha") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_min_alpha = std::stof(argv[i]); + } else if (arg == "--adam-decay") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_decay = std::stof(argv[i]); + } else if (arg == "--adam-decay-min-ndim") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_decay_min_ndim = std::stoi(argv[i]); + } else if (arg == "--adam-beta1") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_beta1 
= std::stof(argv[i]); + } else if (arg == "--adam-beta2") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_beta2 = std::stof(argv[i]); + } else if (arg == "--adam-gclip") { + if (++i >= argc) { + *invalid_param = true; + return true; + } + params->adam_gclip = std::stof(argv[i]); + } else if (arg == "-ngl" || arg == "--n-gpu-layers") { + if (++i >= argc) { + *invalid_param = true; + return true; + } +#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD + params->n_gpu_layers = std::stoi(argv[i]); +#else + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); +#endif + } else if (arg == "-h" || arg == "--help") { + params->print_usage = true; + return true; + } else { + return false; + } + return true; +} + +void finish_processing_train_args(struct train_params_common * params) { + if (params->escape) { + process_escapes(params->sample_start); + } +} + +void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) { + struct train_opt_callback_data * data = (struct train_opt_callback_data *) vdata; + struct train_params_common * params = data->params; + struct train_state * train = data->train; + struct ggml_opt_context * opt = train->opt; + int n_batch = params->n_batch; + int n_ctx = params->n_ctx; + + if (accum_step == 0) { + // time measurement + int64_t now = ggml_time_ms(); + if (now > data->last_time && opt->iter > data->first_iter) { + double dt = (double) (now - data->last_time); + if (data->millis_per_iter == 0.0) { + data->millis_per_iter = dt; + } else { + const double gain = 0.7; + data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain; + } + } + + double remaining_millis = 0.0; + if (data->millis_per_iter > 0.0) { + const int n_iter = params->adam_n_iter; + const int done_iter = opt->iter - data->first_iter; + const int remaining_iter = n_iter - done_iter; + remaining_millis = remaining_iter * data->millis_per_iter; + } + + // file saving + const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every); + if (save_now) { + int new_iters = opt->iter - data->last_save_iter; + train->train_its += new_iters; + train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; + + if (data->save_cb) { + data->save_cb(data->save_data, train); + } + + data->last_save_iter = opt->iter; + } + + // exclude file saving from time measurement, by measuring last_time after saving + data->last_time = ggml_time_ms(); + + *sched = learning_schedule( + opt->iter, + params->warmup, + params->cos_decay_steps, + params->adam_alpha, + params->adam_min_alpha, + params->cos_decay_min, + params->cos_decay_restart, + params->enable_restart); + + int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f); + if (impr_plot > 0) impr_plot = 0; + if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0; + printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f", + __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count, + *sched, opt->loss_after); + + + if (data->millis_per_iter > 0) { + printf(" dt="); + print_duration(data->millis_per_iter); + printf(" eta="); + print_duration(remaining_millis); + } + + float improvement = opt->loss_before - opt->loss_after; + const float plot_scale = 10.0f; + int bar_len = (int)(1 + improvement*plot_scale + 
0.5); + printf(" |"); + for (int i=0; i"); + printf("\n"); + } + + int64_t used_samples = get_example_targets_batch( + data->lctx, + data->tokens_input, + data->target_probs, + train->shuffle_next_sample, + data->shuffled_samples_offs, + data->shuffled_samples_begin, + data->shuffled_samples_size, + data->samples_count, + data->tokens_data, + data->tokens_size, + params->separate_with_eos, + params->separate_with_bos, + params->fill_with_next_samples, + params->sample_random_offsets); + + train->train_samples += used_samples; + train->shuffle_next_sample += used_samples; + + if (train->shuffle_next_sample >= train->shuffle_sample_count) { + ++train->train_epochs; + printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs); + // note: we may have used some samples from the current shuffling more than once + train->shuffle_rng_state_current = train->shuffle_rng_state_next; + train->shuffle_rng_state_next = shuffle_samples( + train->shuffle_rng_state_current, + data->shuffled_samples_offs, + data->shuffled_samples_begin, + data->shuffled_samples_size, + data->samples_begin, + data->samples_size, + data->samples_count); + train->shuffle_next_sample = 0; + } + + const bool last_epoch_reached = (params->n_epochs > 0 && (int64_t) train->train_epochs - data->first_epoch >= params->n_epochs); + if (last_epoch_reached) { + // allow optimization iteration at last epoch to be completed before canceling + if (data->iter_at_last_epoch < 0) { + data->iter_at_last_epoch = opt->iter; + } else if (opt->iter > data->iter_at_last_epoch) { + *cancel = true; + } + } +} diff --git a/applications/src/llama/common/train.h b/applications/src/llama/common/train.h new file mode 100644 index 000000000..263d940c0 --- /dev/null +++ b/applications/src/llama/common/train.h @@ -0,0 +1,233 @@ +// Various helper functions and utilities for training + +#pragma once + +#include +#include +#include + +#include "ggml.h" +#include "llama.h" + +#define LLAMA_TRAIN_MAX_NODES 16384 + +typedef std::string mt19937_state; + +struct train_state { + struct ggml_opt_context * opt; + + uint64_t train_its; + uint64_t train_samples; + uint64_t train_tokens; + uint64_t train_epochs; + + size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes) + mt19937_state shuffle_rng_state_current; + mt19937_state shuffle_rng_state_next; + size_t shuffle_sample_count; + size_t shuffle_next_sample; +}; + +struct train_params_common { + const char * fn_train_data; + const char * fn_checkpoint_in; + const char * fn_checkpoint_out; + const char * pattern_fn_it; + const char * fn_latest; + + bool print_usage; + + int save_every; + + uint32_t seed; + + int n_ctx; + int n_threads; + int n_batch; + int n_gradient_accumulation; + int n_epochs; + int n_gpu_layers; + + bool custom_n_ctx; + + bool use_flash; + bool use_checkpointing; + + std::string sample_start; + bool include_sample_start; + bool escape; + bool overlapping_samples; + bool fill_with_next_samples; + bool separate_with_eos; + bool separate_with_bos; + bool sample_random_offsets; + + bool force_reshuffle; + + int warmup; + int cos_decay_steps; + float cos_decay_restart; + float cos_decay_min; + bool enable_restart; + + int opt_past; + float opt_delta; + int opt_max_no_improvement; + + int adam_n_iter; + float adam_alpha; + float adam_min_alpha; + float adam_decay; + int adam_decay_min_ndim; + float adam_beta1; + float adam_beta2; + float adam_gclip; + float adam_eps_f; +}; + +typedef void (*save_train_files_callback)(void * 
data, struct train_state * train); + +struct train_opt_callback_data { + struct train_params_common * params; + struct train_state * train; + save_train_files_callback save_cb; + void * save_data; + struct llama_context * lctx; + int last_save_iter; + llama_token * tokens_data; + size_t tokens_size; + size_t * samples_begin; + size_t * samples_size; + size_t * shuffled_samples_offs; + size_t * shuffled_samples_begin; + size_t * shuffled_samples_size; + size_t samples_count; + struct ggml_tensor * tokens_input; + struct ggml_tensor * target_probs; + int first_iter; + int first_epoch; + int iter_at_last_epoch; + int64_t last_time; + double millis_per_iter; +}; + +struct train_state * init_train_state(); +void free_train_state(struct train_state * state); + +struct train_params_common get_default_train_params_common(); +void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params); + +bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param); +void finish_processing_train_args(struct train_params_common * params); + +struct random_normal_distribution; +struct random_uniform_distribution; + +struct random_normal_distribution * init_random_normal_distribution (int seed, float mean, float std, float min, float max); +struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max); + +void free_random_normal_distribution (struct random_normal_distribution * rnd); +void free_random_uniform_distribution(struct random_uniform_distribution * rnd); + +struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd); +struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd); + +// generate random float in interval [0,1) +float frand(); +float frand_normal (struct random_normal_distribution * rnd); +float frand_uniform(struct random_uniform_distribution * rnd); + +int clamp (const int v, const int min, const int max); +float fclamp(const float v, const float min, const float max); + +void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0); +void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1); +void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2); +void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3); + +size_t tokenize_file( + struct llama_context * lctx, + const char * filename, + const std::string & sample_start, + bool include_sample_start, + bool overlapping_samples, + unsigned context_length, + std::vector & out_tokens, + std::vector & out_samples_begin, + std::vector & out_samples_size); + +int64_t get_example_targets_batch( + struct llama_context * lctx, + struct ggml_tensor * tokens_input, + struct ggml_tensor * target_probs, + int64_t example_id, + const size_t * samples_offs, + const size_t * samples_begin, + const size_t * samples_size, + size_t samples_count, + const llama_token * train_data, + size_t n_train_data, + bool separate_with_eos, + bool separate_with_bos, + bool fill_with_next_samples, + bool sample_random_offsets); + + +void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state); +mt19937_state mt19937_get_state(const std::mt19937& rng); +mt19937_state mt19937_seed_to_state(unsigned seed); + +mt19937_state shuffle_samples( + const mt19937_state & rng_state, + size_t * shuffled_offs, + size_t * 
shuffled_begins, + size_t * shuffled_sizes, + const size_t * begins, + const size_t * sizes, + size_t count); + +size_t hash_combine(size_t h1, size_t h2); + +size_t compute_samples_hash( + const char* fn, + const size_t* samples_begin, + const size_t* samples_size, + size_t sample_count); + + +std::string replace_str(const char * s, const char * needle, const char * replacement); + +void print_duration(double milliseconds); + +float cosine_decay( + int64_t step, + int64_t decay_steps, + float minimum); + +float cosine_decay_restart( + int64_t step, + int64_t decay_steps, + float minimum, + float restart_step_mult); + +float learning_schedule( + int64_t step, + int64_t warmup_steps, + int64_t decay_steps, + float learning_rate, + float overall_minimum, + float cos_decay_minimum, + float cos_decay_restart_step_mult, + bool enable_restart); + +void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name); + +void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt); +void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt); + +bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train); +void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train); + +std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration); + +void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel); diff --git a/applications/src/llama/examples/main/main.cpp b/applications/src/llama/examples/main/main.cpp new file mode 100644 index 000000000..c096f110b --- /dev/null +++ b/applications/src/llama/examples/main/main.cpp @@ -0,0 +1,873 @@ +#include "common.h" + +#include "console.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#include +#include +#elif defined (_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#endif + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +static llama_context ** g_ctx; +static llama_model ** g_model; +static gpt_params * g_params; +static std::vector * g_input_tokens; +static std::ostringstream * g_output_ss; +static std::vector * g_output_tokens; +static bool is_interacting = false; + + +static void write_logfile( + const llama_context * ctx, const gpt_params & params, const llama_model * model, + const std::vector & input_tokens, const std::string & output, + const std::vector & output_tokens +) { + if (params.logdir.empty()) { + return; + } + + const std::string timestamp = get_sortable_timestamp(); + + const bool success = create_directory_with_parents(params.logdir); + if (!success) { + fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", + __func__, params.logdir.c_str()); + return; + } + + const std::string logfile_path = params.logdir + timestamp + ".yml"; + FILE * logfile = fopen(logfile_path.c_str(), "w"); + + if (logfile == NULL) { + fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); + return; + } + + fprintf(logfile, "binary: main\n"); + char model_desc[128]; + llama_model_desc(model, model_desc, sizeof(model_desc)); + dump_non_result_info_yaml(logfile, params, ctx, timestamp, 
input_tokens, model_desc); + + fprintf(logfile, "\n"); + fprintf(logfile, "######################\n"); + fprintf(logfile, "# Generation Results #\n"); + fprintf(logfile, "######################\n"); + fprintf(logfile, "\n"); + + dump_string_yaml_multiline(logfile, "output", output.c_str()); + dump_vector_int_yaml(logfile, "output_tokens", output_tokens); + + llama_dump_timing_info_yaml(logfile, ctx); + fclose(logfile); +} + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) +static void sigint_handler(int signo) { + if (signo == SIGINT) { + if (!is_interacting) { + is_interacting = true; + } else { + console::cleanup(); + printf("\n"); + llama_print_timings(*g_ctx); + write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); + _exit(130); + } + } +} +#endif + +static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + LOG_TEE("%s", text); +} + +int main(int argc, char ** argv) { + gpt_params params; + g_params = ¶ms; + + if (!gpt_params_parse(argc, argv, params)) { + return 1; + } + llama_sampling_params & sparams = params.sparams; + +#ifndef LOG_DISABLE_LOGS + log_set_target(log_filename_generator("main", "log")); + LOG_TEE("Log start\n"); + log_dump_cmdline(argc, argv); + llama_log_set(llama_log_callback_logTee, nullptr); +#endif // LOG_DISABLE_LOGS + + // TODO: Dump params ? + //LOG("Params perplexity: %s\n", LOG_TOSTR(params.perplexity)); + + // save choice to use color for later + // (note for later: this is a slightly awkward choice) + console::init(params.simple_io, params.use_color); + atexit([]() { console::cleanup(); }); + + if (params.logits_all) { + printf("\n************\n"); + printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); + printf("************\n\n"); + + return 0; + } + + if (params.embedding) { + printf("\n************\n"); + printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__); + printf("************\n\n"); + + return 0; + } + + if (params.n_ctx != 0 && params.n_ctx < 8) { + LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__); + params.n_ctx = 8; + } + + if (params.rope_freq_base != 0.0) { + LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base); + } + + if (params.rope_freq_scale != 0.0) { + LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); + } + + LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); + LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); + + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + + LOG_TEE("%s: seed = %u\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.random_prompt) { + params.prompt = gpt_random_prompt(rng); + } + + LOG("%s: llama backend init\n", __func__); + llama_backend_init(params.numa); + + llama_model * model; + llama_context * ctx; + llama_context * ctx_guidance = NULL; + g_model = &model; + g_ctx = &ctx; + + // load the model and apply lora adapter, if any + LOG("%s: load the model and apply lora adapter, if any\n", __func__); + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (sparams.cfg_scale > 1.f) { + struct llama_context_params lparams = llama_context_params_from_gpt_params(params); + ctx_guidance = llama_new_context_with_model(model, lparams); + } + + if (model == 
NULL) { + LOG_TEE("%s: error: unable to load model\n", __func__); + return 1; + } + + const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx = llama_n_ctx(ctx); + LOG("n_ctx: %d\n", n_ctx); + + if (n_ctx > n_ctx_train) { + LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n", + __func__, n_ctx_train, n_ctx); + } + + // print system information + { + LOG_TEE("\n"); + LOG_TEE("%s\n", get_system_info(params).c_str()); + } + + std::string path_session = params.path_prompt_cache; + std::vector session_tokens; + + if (!path_session.empty()) { + LOG_TEE("%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str()); + + // fopen to check for existing session + FILE * fp = std::fopen(path_session.c_str(), "rb"); + if (fp != NULL) { + std::fclose(fp); + + session_tokens.resize(n_ctx); + size_t n_token_count_out = 0; + if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { + LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str()); + return 1; + } + session_tokens.resize(n_token_count_out); + llama_set_rng_seed(ctx, params.seed); + + LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size()); + } else { + LOG_TEE("%s: session file does not exist, will create\n", __func__); + } + } + + const bool add_bos = llama_should_add_bos_token(model); + LOG("add_bos: %d\n", add_bos); + + std::vector embd_inp; + + if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) { + LOG("tokenize the prompt\n"); + if (params.chatml) { + params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>"; + } + embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + } else { + LOG("use session tokens\n"); + embd_inp = session_tokens; + } + + LOG("prompt: \"%s\"\n", log_tostr(params.prompt)); + LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + + // Should not run without any tokens + if (embd_inp.empty()) { + embd_inp.push_back(llama_token_bos(model)); + LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + } + + // Tokenize negative prompt + std::vector guidance_inp; + int guidance_offset = 0; + int original_prompt_len = 0; + if (ctx_guidance) { + LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt)); + + guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true); + LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str()); + + std::vector original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str()); + + original_prompt_len = original_inp.size(); + guidance_offset = (int)guidance_inp.size() - original_prompt_len; + LOG("original_prompt_len: %s", log_tostr(original_prompt_len)); + LOG("guidance_offset: %s", log_tostr(guidance_offset)); + } + + if ((int) embd_inp.size() > n_ctx - 4) { + LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); + return 1; + } + + // debug message about similarity of saved session, if applicable + size_t n_matching_session_tokens = 0; + if (!session_tokens.empty()) { + for (llama_token id : session_tokens) { + if (n_matching_session_tokens >= embd_inp.size() || id != 
embd_inp[n_matching_session_tokens]) { + break; + } + n_matching_session_tokens++; + } + if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { + LOG_TEE("%s: using full prompt from session file\n", __func__); + } else if (n_matching_session_tokens >= embd_inp.size()) { + LOG_TEE("%s: session file has exact match for prompt!\n", __func__); + } else if (n_matching_session_tokens < (embd_inp.size() / 2)) { + LOG_TEE("%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n", + __func__, n_matching_session_tokens, embd_inp.size()); + } else { + LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", + __func__, n_matching_session_tokens, embd_inp.size()); + } + + // remove any "future" tokens that we might have inherited from the previous session + llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); + } + + LOGLN( + "recalculate the cached logits (check): embd_inp.empty() %s, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu, embd_inp.size() %zu", + log_tostr(embd_inp.empty()), n_matching_session_tokens, embd_inp.size(), session_tokens.size(), embd_inp.size()); + + // if we will use the cache for the full prompt without reaching the end of the cache, force + // reevaluation of the last token token to recalculate the cached logits + if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) { + LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); + + session_tokens.resize(embd_inp.size() - 1); + } + + // number of tokens to keep when resetting context + if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) { + params.n_keep = (int)embd_inp.size(); + } + + // prefix & suffix for instruct mode + const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true); + const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true); + + LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str()); + LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str()); + + // chatml prefix & suffix + const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true); + const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true); + + LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str()); + LOG("cml_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_sfx).c_str()); + + // in instruct mode, we inject a prefix and a suffix to each input by the user + if (params.instruct) { + params.interactive_first = true; + params.antiprompt.push_back("### Instruction:\n\n"); + } + // similar for chatml mode + else if (params.chatml) { + params.interactive_first = true; + params.antiprompt.push_back("<|im_start|>user\n"); + } + + // enable interactive mode if interactive start is specified + if (params.interactive_first) { + params.interactive = true; + } + + if (params.verbose_prompt) { + LOG_TEE("\n"); + LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + for (int i = 0; i < (int) embd_inp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); + } + + if (ctx_guidance) { + LOG_TEE("\n"); + LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str()); + LOG_TEE("%s: number of tokens in negative prompt 
= %zu\n", __func__, guidance_inp.size()); + for (int i = 0; i < (int) guidance_inp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str()); + } + } + + if (params.n_keep > 0) { + LOG_TEE("%s: static prompt based on n_keep: '", __func__); + for (int i = 0; i < params.n_keep; i++) { + LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str()); + } + LOG_TEE("'\n"); + } + LOG_TEE("\n"); + } + + if (params.interactive) { +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = sigint_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + LOG_TEE("%s: interactive mode on.\n", __func__); + + if (!params.antiprompt.empty()) { + for (const auto & antiprompt : params.antiprompt) { + LOG_TEE("Reverse prompt: '%s'\n", antiprompt.c_str()); + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, antiprompt, false, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } + } + } + + if (params.input_prefix_bos) { + LOG_TEE("Input prefix with BOS\n"); + } + + if (!params.input_prefix.empty()) { + LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str()); + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } + } + + if (!params.input_suffix.empty()) { + LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } + } + } + LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); + LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str()); + LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + LOG_TEE("\n\n"); + + if (params.interactive) { + const char *control_message; + if (params.multiline_input) { + control_message = " - To return control to LLaMa, end your input with '\\'.\n" + " - To return control without starting a new line, end your input with '/'.\n"; + } else { + control_message = " - Press Return to return control to LLaMa.\n" + " - To return control without starting a new line, end your input with '/'.\n" + " - If you want to submit another line, end your input with '\\'.\n"; + } + LOG_TEE("== Running in interactive mode. 
==\n"); +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) + LOG_TEE( " - Press Ctrl+C to interject at any time.\n"); +#endif + LOG_TEE( "%s\n", control_message); + + is_interacting = params.interactive_first; + } + + bool is_antiprompt = false; + bool input_echo = true; + bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size(); + + int n_past = 0; + int n_remain = params.n_predict; + int n_consumed = 0; + int n_session_consumed = 0; + int n_past_guidance = 0; + + std::vector input_tokens; g_input_tokens = &input_tokens; + std::vector output_tokens; g_output_tokens = &output_tokens; + std::ostringstream output_ss; g_output_ss = &output_ss; + + // the first thing we will do is to output the prompt, so set color accordingly + console::set_display(console::prompt); + + std::vector embd; + std::vector embd_guidance; + + struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + + while ((n_remain != 0 && !is_antiprompt) || params.interactive) { + // predict + if (!embd.empty()) { + // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via + // --prompt or --file which uses the same value. + int max_embd_size = n_ctx - 4; + + // Ensure the input doesn't exceed the context size by truncating embd if necessary. + if ((int) embd.size() > max_embd_size) { + const int skipped_tokens = (int) embd.size() - max_embd_size; + embd.resize(max_embd_size); + + console::set_display(console::error); + printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); + console::set_display(console::reset); + fflush(stdout); + } + + // infinite text generation via context swapping + // if we run out of context: + // - take the n_keep first tokens from the original prompt (via n_past) + // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches + if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { + if (params.n_predict == -2) { + LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); + break; + } + + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; + + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, n_left, n_ctx, params.n_keep, n_discard); + + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + n_past -= n_discard; + + if (ctx_guidance) { + n_past_guidance -= n_discard; + } + + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + + LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + + LOG("clear session path\n"); + path_session.clear(); + } + + // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) + if (n_session_consumed < (int) session_tokens.size()) { + size_t i = 0; + for ( ; i < embd.size(); i++) { + if (embd[i] != session_tokens[n_session_consumed]) { + session_tokens.resize(n_session_consumed); + break; + } + + n_past++; + n_session_consumed++; + + if (n_session_consumed >= (int) session_tokens.size()) { + ++i; + break; + } + } + if (i > 0) { + embd.erase(embd.begin(), embd.begin() + i); + } + } + + // evaluate tokens in batches + // embd is typically prepared beforehand to fit within a batch, but not always + if (ctx_guidance) { + int input_size = 0; + llama_token * input_buf = NULL; + + if (n_past_guidance 
< (int) guidance_inp.size()) { + // Guidance context should have the same data with these modifications: + // + // * Replace the initial prompt + // * Shift everything by guidance_offset + embd_guidance = guidance_inp; + if (embd.begin() + original_prompt_len < embd.end()) { + embd_guidance.insert( + embd_guidance.end(), + embd.begin() + original_prompt_len, + embd.end() + ); + } + + input_buf = embd_guidance.data(); + input_size = embd_guidance.size(); + + LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str()); + } else { + input_buf = embd.data(); + input_size = embd.size(); + } + + for (int i = 0; i < input_size; i += params.n_batch) { + int n_eval = std::min(input_size - i, params.n_batch); + if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) { + LOG_TEE("%s : failed to eval\n", __func__); + return 1; + } + + n_past_guidance += n_eval; + } + } + + for (int i = 0; i < (int) embd.size(); i += params.n_batch) { + int n_eval = (int) embd.size() - i; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + + LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { + LOG_TEE("%s : failed to eval\n", __func__); + return 1; + } + + n_past += n_eval; + + LOG("n_past = %d\n", n_past); + } + + if (!embd.empty() && !path_session.empty()) { + session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); + n_session_consumed = session_tokens.size(); + } + } + + embd.clear(); + embd_guidance.clear(); + + if ((int) embd_inp.size() <= n_consumed && !is_interacting) { + // optionally save the session on first sample (for faster prompt loading next time) + if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) { + need_to_save_session = false; + llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); + + LOG("saved session to %s\n", path_session.c_str()); + } + + const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); + + llama_sampling_accept(ctx_sampling, ctx, id, true); + + LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); + + embd.push_back(id); + + // echo this to console + input_echo = true; + + // decrement remaining sampling budget + --n_remain; + + LOG("n_remain: %d\n", n_remain); + } else { + // some user input remains from prompt or interaction, forward it to processing + LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); + while ((int) embd_inp.size() > n_consumed) { + embd.push_back(embd_inp[n_consumed]); + + // push the prompt in the sampling context in order to apply repetition penalties later + // for the prompt, we don't apply grammar rules + llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false); + + ++n_consumed; + if ((int) embd.size() >= params.n_batch) { + break; + } + } + } + + // display text + if (input_echo) { + for (auto id : embd) { + const std::string token_str = llama_token_to_piece(ctx, id); + printf("%s", token_str.c_str()); + + if (embd.size() > 1) { + input_tokens.push_back(id); + } else { + output_tokens.push_back(id); + output_ss << token_str; + } + } + fflush(stdout); + } + // reset color to default if there is no pending user input + if (input_echo && (int) embd_inp.size() == n_consumed) { + console::set_display(console::reset); + } + + // if not currently processing queued inputs; + if ((int) embd_inp.size() <= n_consumed) { + // check 
for reverse prompt in the last n_prev tokens + if (!params.antiprompt.empty()) { + const int n_prev = 32; + const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev); + + is_antiprompt = false; + // Check if each of the reverse prompts appears at the end of the output. + // If we're not running interactively, the reverse prompt might be tokenized with some following characters + // so we'll compensate for that by widening the search window a bit. + for (std::string & antiprompt : params.antiprompt) { + size_t extra_padding = params.interactive ? 0 : 2; + size_t search_start_pos = last_output.length() > static_cast(antiprompt.length() + extra_padding) + ? last_output.length() - static_cast(antiprompt.length() + extra_padding) + : 0; + + if (last_output.find(antiprompt, search_start_pos) != std::string::npos) { + if (params.interactive) { + is_interacting = true; + } + is_antiprompt = true; + break; + } + } + + if (is_antiprompt) { + LOG("found antiprompt: %s\n", last_output.c_str()); + } + } + + // deal with end of text token in interactive mode + if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) { + LOG("found EOS token\n"); + + if (params.interactive) { + if (!params.antiprompt.empty()) { + // tokenize and inject first reverse prompt + const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true); + embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); + is_antiprompt = true; + } + + is_interacting = true; + printf("\n"); + } else if (params.instruct || params.chatml) { + is_interacting = true; + } + } + + if (n_past > 0 && is_interacting) { + LOG("waiting for user input\n"); + + if (params.instruct || params.chatml) { + printf("\n> "); + } + + if (params.input_prefix_bos) { + LOG("adding input prefix BOS token\n"); + embd_inp.push_back(llama_token_bos(model)); + } + + std::string buffer; + if (!params.input_prefix.empty()) { + LOG("appending input prefix: '%s'\n", params.input_prefix.c_str()); + printf("%s", params.input_prefix.c_str()); + } + + // color user input only + console::set_display(console::user_input); + + std::string line; + bool another_line = true; + do { + another_line = console::readline(line, params.multiline_input); + buffer += line; + } while (another_line); + + // done taking input, reset color + console::set_display(console::reset); + + // Add tokens to embd only if the input buffer is non-empty + // Entering a empty line lets the user pass control back + if (buffer.length() > 1) { + // append input suffix if any + if (!params.input_suffix.empty()) { + LOG("appending input suffix: '%s'\n", params.input_suffix.c_str()); + printf("%s", params.input_suffix.c_str()); + } + + LOG("buffer: '%s'\n", buffer.c_str()); + + const size_t original_size = embd_inp.size(); + + // instruct mode: insert instruction prefix + if (params.instruct && !is_antiprompt) { + LOG("inserting instruction prefix\n"); + n_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + } + // chatml mode: insert user chat prefix + if (params.chatml && !is_antiprompt) { + LOG("inserting chatml prefix\n"); + n_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end()); + } + if (params.escape) { + process_escapes(buffer); + } + + const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); + const auto line_inp = ::llama_tokenize(ctx, buffer, false, false); + const auto line_sfx = ::llama_tokenize(ctx, 
params.input_suffix, false, true); + LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str()); + + embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end()); + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end()); + + // instruct mode: insert response suffix + if (params.instruct) { + LOG("inserting instruction suffix\n"); + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + // chatml mode: insert assistant chat suffix + if (params.chatml) { + LOG("inserting chatml suffix\n"); + embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end()); + } + + for (size_t i = original_size; i < embd_inp.size(); ++i) { + const llama_token token = embd_inp[i]; + output_tokens.push_back(token); + output_ss << llama_token_to_piece(ctx, token); + } + + n_remain -= line_inp.size(); + LOG("n_remain: %d\n", n_remain); + } else { + LOG("empty line, passing control back\n"); + } + + input_echo = false; // do not echo this again + } + + if (n_past > 0) { + if (is_interacting) { + llama_sampling_reset(ctx_sampling); + } + is_interacting = false; + } + } + + // end of text token + if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) { + LOG_TEE(" [end of text]\n"); + break; + } + + // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. + // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size). + if (params.interactive && n_remain <= 0 && params.n_predict >= 0) { + n_remain = params.n_predict; + is_interacting = true; + } + } + + if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) { + LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str()); + llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); + } + + llama_print_timings(ctx); + write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + + if (ctx_guidance) { llama_free(ctx_guidance); } + llama_free(ctx); + llama_free_model(model); + + llama_sampling_free(ctx_sampling); + llama_backend_free(); + +#ifndef LOG_DISABLE_LOGS + LOG_TEE("Log end\n"); +#endif // LOG_DISABLE_LOGS + + return 0; +} diff --git a/applications/src/llama/ggml-alloc.c b/applications/src/llama/ggml-alloc.c new file mode 100644 index 000000000..d3049efb4 --- /dev/null +++ b/applications/src/llama/ggml-alloc.c @@ -0,0 +1,802 @@ +#include "ggml-alloc.h" +#include "ggml-backend-impl.h" +#include "ggml.h" +#include "ggml-impl.h" +#include +#include +#include +#include +#include +#include + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MAX_FREE_BLOCKS 256 + +//#define GGML_ALLOCATOR_DEBUG + +//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__) +#define AT_PRINTF(...) + +// TODO: GGML_PAD ? 
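Before the allocator internals, a quick note on the alignment rule used throughout ggml-alloc.c: aligned_offset() below pads an offset so that base + offset lands on a power-of-two boundary, and ggml_tallocr_alloc() calls it with a NULL base, which reduces it to rounding a size up to the next multiple of the buffer alignment. A minimal standalone sketch of that rounding (illustration only, not part of the patch):

/* standalone sketch of the rounding performed by aligned_offset() below */
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

static size_t round_up_offset(const void * buffer, size_t offset, size_t alignment) {
    assert(alignment && !(alignment & (alignment - 1))); /* power of 2, as in aligned_offset() */
    size_t misalign = ((uintptr_t) buffer + offset) % alignment;
    return offset + (alignment - misalign) % alignment;  /* add just enough padding */
}

int main(void) {
    /* with a NULL base this rounds a size up to the next multiple of the alignment */
    printf("%zu\n", round_up_offset(NULL, 0,  32)); /* 0  */
    printf("%zu\n", round_up_offset(NULL, 1,  32)); /* 32 */
    printf("%zu\n", round_up_offset(NULL, 33, 32)); /* 64 */
    return 0;
}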
+static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) { + assert(alignment && !(alignment & (alignment - 1))); // power of 2 + size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment; + return offset + align; +} + +struct free_block { + void * addr; + size_t size; +}; + +struct ggml_tallocr { + struct ggml_backend_buffer * buffer; + bool buffer_owned; + void * base; + size_t alignment; + + int n_free_blocks; + struct free_block free_blocks[MAX_FREE_BLOCKS]; + + size_t max_size; + + bool measure; + +#ifdef GGML_ALLOCATOR_DEBUG + struct ggml_tensor * allocated_tensors[1024]; +#endif +}; + +#ifdef GGML_ALLOCATOR_DEBUG +static void add_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) { + for (int i = 0; i < 1024; i++) { + if (alloc->allocated_tensors[i] == NULL) { + alloc->allocated_tensors[i] = tensor; + return; + } + } + GGML_ASSERT(!"out of allocated_tensors"); +} +static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) { + for (int i = 0; i < 1024; i++) { + if (alloc->allocated_tensors[i] == tensor || + (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) { + alloc->allocated_tensors[i] = NULL; + return; + } + } + printf("tried to free tensor %s not found\n", tensor->name); + GGML_ASSERT(!"tensor not found"); +} +#endif + +// check if a tensor is allocated by this buffer +static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) { + return tensor->buffer == alloc->buffer; +} + +static bool ggml_is_view(struct ggml_tensor * t) { + return t->view_src != NULL; +} + +void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) { + GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources + GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated + + size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); + size = aligned_offset(NULL, size, alloc->alignment); + + AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size); + + size_t max_avail = 0; + + // find the best fitting free block besides the last block + int best_fit_block = -1; + size_t best_fit_size = SIZE_MAX; + for (int i = 0; i < alloc->n_free_blocks - 1; i++) { + struct free_block * block = &alloc->free_blocks[i]; + max_avail = MAX(max_avail, block->size); + if (block->size >= size && block->size <= best_fit_size) { + best_fit_block = i; + best_fit_size = block->size; + } + } + + AT_PRINTF("block %d\n", best_fit_block); + + if (best_fit_block == -1) { + // the last block is our last resort + struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; + max_avail = MAX(max_avail, block->size); + if (block->size >= size) { + best_fit_block = alloc->n_free_blocks - 1; + } else { + fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n", + __func__, size, max_avail); + GGML_ASSERT(!"not enough space in the buffer"); + return; + } + } + struct free_block * block = &alloc->free_blocks[best_fit_block]; + void * addr = block->addr; + block->addr = (char*)block->addr + size; + block->size -= size; + if (block->size == 0) { + // remove block if empty + alloc->n_free_blocks--; + for (int j = best_fit_block; j < alloc->n_free_blocks; j++) { + alloc->free_blocks[j] = alloc->free_blocks[j+1]; + } + } + + tensor->data = addr; + tensor->buffer = alloc->buffer; + if (!alloc->measure) { + 
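The search in ggml_tallocr_alloc() above is a plain best-fit policy: every free block except the last is a candidate, the smallest block that still fits wins, and the last block acts as the remaining tail of the buffer and is only used as a last resort. A small standalone sketch of that policy, independent of the ggml structs (the block sizes here are made up for illustration):

/* standalone sketch of the best-fit policy used by ggml_tallocr_alloc() */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

struct demo_block { size_t size; };

/* return the index of the tightest-fitting block among blocks[0 .. n-2],
   or n-1 (the "last resort" block) if none of the others is large enough */
static int best_fit(const struct demo_block * blocks, int n, size_t size) {
    int    best_fit_block = -1;
    size_t best_fit_size  = SIZE_MAX;
    for (int i = 0; i < n - 1; i++) {
        if (blocks[i].size >= size && blocks[i].size <= best_fit_size) {
            best_fit_block = i;
            best_fit_size  = blocks[i].size;
        }
    }
    return best_fit_block == -1 ? n - 1 : best_fit_block;
}

int main(void) {
    /* made-up free list: the last entry plays the role of the large remainder block */
    struct demo_block blocks[] = { {256}, {64}, {128}, {4096} };
    printf("%d\n", best_fit(blocks, 4, 100)); /* 2: the 128-byte block is the tightest fit */
    printf("%d\n", best_fit(blocks, 4, 512)); /* 3: none of the others fit, use the last block */
    return 0;
}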
ggml_backend_buffer_init_tensor(alloc->buffer, tensor); + } + +#ifdef GGML_ALLOCATOR_DEBUG + add_allocated_tensor(alloc, tensor); + size_t cur_max = (char*)addr - (char*)alloc->base + size; + if (cur_max > alloc->max_size) { + printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); + for (int i = 0; i < 1024; i++) { + if (alloc->allocated_tensors[i]) { + printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0); + } + } + printf("\n"); + } +#endif + + alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size); +} + +// this is a very naive implementation, but for our case the number of free blocks should be very small +static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) { + if (ggml_tallocr_is_own(alloc, tensor) == false) { + // the tensor was not allocated in this buffer + // this can happen because the graph allocator will try to free weights and other tensors from different buffers + // the easiest way to deal with this is just to ignore it + // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer); + return; + } + + void * ptr = tensor->data; + + size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); + size = aligned_offset(NULL, size, alloc->alignment); + AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks); + +#ifdef GGML_ALLOCATOR_DEBUG + remove_allocated_tensor(alloc, tensor); +#endif + + // see if we can merge with an existing block + for (int i = 0; i < alloc->n_free_blocks; i++) { + struct free_block * block = &alloc->free_blocks[i]; + // check if ptr is at the end of the block + if ((char*)block->addr + block->size == ptr) { + block->size += size; + // check if we can merge with the next block + if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) { + block->size += alloc->free_blocks[i+1].size; + alloc->n_free_blocks--; + for (int j = i+1; j < alloc->n_free_blocks; j++) { + alloc->free_blocks[j] = alloc->free_blocks[j+1]; + } + } + return; + } + // check if ptr is at the beginning of the block + if ((char*)ptr + size == block->addr) { + block->addr = ptr; + block->size += size; + // check if we can merge with the previous block + if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) { + alloc->free_blocks[i-1].size += block->size; + alloc->n_free_blocks--; + for (int j = i; j < alloc->n_free_blocks; j++) { + alloc->free_blocks[j] = alloc->free_blocks[j+1]; + } + } + return; + } + } + // otherwise, add a new block + GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks"); + // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster) + int insert_pos = 0; + while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) { + insert_pos++; + } + // shift all blocks from insert_pos onward to make room for the new block + for (int i = alloc->n_free_blocks; i > insert_pos; i--) { + alloc->free_blocks[i] = alloc->free_blocks[i-1]; + } + // insert the new block + alloc->free_blocks[insert_pos].addr = ptr; + alloc->free_blocks[insert_pos].size = size; + alloc->n_free_blocks++; +} + +void ggml_tallocr_reset(ggml_tallocr_t alloc) { + alloc->n_free_blocks = 1; + size_t align_offset = aligned_offset(alloc->base, 0, 
alloc->alignment); + alloc->free_blocks[0].addr = (char *)alloc->base + align_offset; + + if (alloc->measure) { + alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows + } else { + alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset; + } +} + +ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) { + struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size); + + ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr)); + + *alloc = (struct ggml_tallocr) { + /*.buffer = */ buffer, + /*.buffer_owned = */ true, + /*.base = */ ggml_backend_buffer_get_base(buffer), + /*.alignment = */ alignment, + /*.n_free_blocks = */ 0, + /*.free_blocks = */ {{0}}, + /*.max_size = */ 0, + /*.measure = */ false, +#ifdef GGML_ALLOCATOR_DEBUG + /*.allocated_tensors = */ {0}, +#endif + }; + + ggml_tallocr_reset(alloc); + + return alloc; +} + +ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment) { + ggml_tallocr_t alloc = ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment); + alloc->measure = true; + + return alloc; +} + +ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) { + // create a backend buffer to get the correct tensor allocation sizes + ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1); + + // TODO: move alloc initialization to a common ggml_tallocr_new_impl function + ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer); + alloc->buffer_owned = true; + alloc->measure = true; + ggml_tallocr_reset(alloc); + return alloc; +} + +ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size); + ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer); + alloc->buffer_owned = true; + return alloc; +} + +ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) { + ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr)); + + *alloc = (struct ggml_tallocr) { + /*.buffer = */ buffer, + /*.buffer_owned = */ false, + /*.base = */ ggml_backend_buffer_get_base(buffer), + /*.alignment = */ ggml_backend_buffer_get_alignment(buffer), + /*.n_free_blocks = */ 0, + /*.free_blocks = */ {{0}}, + /*.max_size = */ 0, + /*.measure = */ false, +#ifdef GGML_ALLOCATOR_DEBUG + /*.allocated_tensors = */ {0}, +#endif + }; + + ggml_tallocr_reset(alloc); + + return alloc; +} + +struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) { + return alloc->buffer; +} + +void ggml_tallocr_free(ggml_tallocr_t alloc) { + if (alloc == NULL) { + return; + } + + if (alloc->buffer_owned) { + ggml_backend_buffer_free(alloc->buffer); + } + free(alloc); +} + +bool ggml_tallocr_is_measure(ggml_tallocr_t alloc) { + return alloc->measure; +} + +size_t ggml_tallocr_max_size(ggml_tallocr_t alloc) { + return alloc->max_size; +} + +// graph allocator + +struct hash_node { + int n_children; + int n_views; +}; + +struct ggml_gallocr { + ggml_tallocr_t talloc; + struct ggml_hash_set hash_set; + struct hash_node * hash_values; + size_t hash_values_size; + ggml_tallocr_t * hash_allocs; + int * parse_seq; + int parse_seq_len; +}; + +ggml_gallocr_t ggml_gallocr_new(void) { + ggml_gallocr_t galloc = (ggml_gallocr_t)malloc(sizeof(struct ggml_gallocr)); + + *galloc = (struct ggml_gallocr) { + /*.talloc = */ NULL, + /*.hash_set = */ {0}, + /*.hash_values = */ NULL, + 
/*.hash_values_size = */ 0, + /*.hash_allocs = */ NULL, + /*.parse_seq = */ NULL, + /*.parse_seq_len = */ 0, + }; + + return galloc; +} + +void ggml_gallocr_free(ggml_gallocr_t galloc) { + if (galloc == NULL) { + return; + } + + if (galloc->hash_set.keys != NULL) { + free(galloc->hash_set.keys); + } + if (galloc->hash_values != NULL) { + free(galloc->hash_values); + } + if (galloc->hash_allocs != NULL) { + free(galloc->hash_allocs); + } + if (galloc->parse_seq != NULL) { + free(galloc->parse_seq); + } + free(galloc); +} + +void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n) { + free(galloc->parse_seq); + galloc->parse_seq = malloc(sizeof(int) * n); + + for (int i = 0; i < n; i++) { + galloc->parse_seq[i] = list[i]; + } + galloc->parse_seq_len = n; +} + +static struct hash_node * hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) { + size_t i = ggml_hash_find_or_insert(galloc->hash_set, t); + return &galloc->hash_values[i]; +} + +static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { + if (a->type != b->type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (a->ne[i] != b->ne[i]) { + return false; + } + if (a->nb[i] != b->nb[i]) { + return false; + } + } + return true; +} + +static bool ggml_op_can_inplace(enum ggml_op op) { + switch (op) { + case GGML_OP_SCALE: + case GGML_OP_DIAG_MASK_ZERO: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_ADD: + case GGML_OP_ADD1: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_UNARY: + case GGML_OP_ROPE: + case GGML_OP_RMS_NORM: + case GGML_OP_SOFT_MAX: + return true; + + default: + return false; + } +} + +static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * node) { + if (galloc->talloc != NULL) { + return galloc->talloc; + } + + return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)]; +} + +static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) { + ggml_tallocr_t alloc = node_tallocr(galloc, view); + + GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL); + if (update_backend) { + view->backend = view->view_src->backend; + } + view->buffer = view->view_src->buffer; + view->data = (char *)view->view_src->data + view->view_offs; + + // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend + // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras + assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft); + + if (!alloc->measure) { + ggml_backend_buffer_init_tensor(alloc->buffer, view); + } +} + +static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) { + ggml_tallocr_t alloc = node_tallocr(galloc, node); + + if (node->data == NULL) { + if (ggml_is_view(node)) { + init_view(galloc, node, true); + } else { + // see if we can reuse a parent's buffer (inplace) + if (ggml_op_can_inplace(node->op)) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * parent = node->src[i]; + if (parent == NULL) { + break; + } + + // if the node's data is external, then we cannot re-use it + if (ggml_tallocr_is_own(alloc, parent) == false) { + AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data); + continue; + } + + struct hash_node * p_hn = hash_get(galloc, parent); + if (parent->data != NULL && 
p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) { + if (ggml_is_view(parent)) { + struct ggml_tensor * view_src = parent->view_src; + struct hash_node * view_src_hn = hash_get(galloc, view_src); + if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { + // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite + // the parent's data that it will need later (same layout requirement). the problem is that then + // we cannot free the tensor because the original address of the allocation is lost. + // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views + // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data) + AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name); + node->view_src = view_src; + view_src_hn->n_views += 1; + init_view(galloc, node, false); + return; + } + } else { + AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name); + node->view_src = parent; + p_hn->n_views += 1; + init_view(galloc, node, false); + return; + } + } + } + } + ggml_tallocr_alloc(alloc, node); + } + } +} + +static void free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) { + ggml_tallocr_t alloc = node_tallocr(galloc, node); + + ggml_tallocr_free_tensor(alloc, node); +} + +static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * gf) { + const int * parse_seq = galloc->parse_seq; + int parse_seq_len = galloc->parse_seq_len; + + // count number of children and views + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (ggml_is_view(node)) { + struct ggml_tensor * view_src = node->view_src; + hash_get(galloc, view_src)->n_views += 1; + if (node->buffer == NULL && node->data != NULL) { + // view of a pre-allocated tensor, didn't call init_view() yet + init_view(galloc, node, true); + } + } + + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; + } + hash_get(galloc, parent)->n_children += 1; + if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) { + init_view(galloc, parent, true); + } + } + } + + // allocate tensors + // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers + int last_barrier_pos = 0; + int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes; + + for (int ind = 0; ind < n_nodes; ind++) { + // allocate a node if there is no parse_seq or this is not a barrier + if (parse_seq_len == 0 || parse_seq[ind] != -1) { + int i = parse_seq_len ? 
parse_seq[ind] : ind; + struct ggml_tensor * node = gf->nodes[i]; + + // allocate parents (leafs) + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; + } + allocate_node(galloc, parent); + } + + // allocate node + allocate_node(galloc, node); + + AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name); + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; + } + AT_PRINTF("%s", parent->name); + if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) { + AT_PRINTF(", "); + } + } + AT_PRINTF("\n"); + } + + // update parents + // update immediately if there is no parse_seq + // update only at barriers if there is parse_seq + if ((parse_seq_len == 0) || parse_seq[ind] == -1) { + int update_start = parse_seq_len ? last_barrier_pos : ind; + int update_end = parse_seq_len ? ind : ind + 1; + for (int i = update_start; i < update_end; i++) { + int node_i = parse_seq_len ? parse_seq[i] : i; + struct ggml_tensor * node = gf->nodes[node_i]; + + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; + } + struct hash_node * p_hn = hash_get(galloc, parent); + p_hn->n_children -= 1; + + //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views); + + if (p_hn->n_children == 0 && p_hn->n_views == 0) { + if (ggml_is_view(parent)) { + struct ggml_tensor * view_src = parent->view_src; + struct hash_node * view_src_hn = hash_get(galloc, view_src); + view_src_hn->n_views -= 1; + AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); + if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) { + free_node(galloc, view_src); + } + } + else { + free_node(galloc, parent); + } + } + } + } + AT_PRINTF("\n"); + if (parse_seq_len) { + last_barrier_pos = ind + 1; + } + } + } +} + +size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph) { + size_t hash_size = graph->visited_hash_table.size; + + // check if the hash table is initialized and large enough + if (galloc->hash_set.size < hash_size) { + if (galloc->hash_set.keys != NULL) { + free(galloc->hash_set.keys); + } + if (galloc->hash_values != NULL) { + free(galloc->hash_values); + } + galloc->hash_set.keys = malloc(sizeof(struct ggml_tensor *) * hash_size); + galloc->hash_set.size = hash_size; + galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size); + } + + // reset hash table + memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * hash_size); + memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size); + + galloc->talloc = talloc; + ggml_tallocr_alloc_graph_impl(galloc, graph); + galloc->talloc = NULL; + + size_t max_size = ggml_tallocr_max_size(talloc); + + return max_size; +} + +void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_talloc) { + const size_t hash_size = hash_set.size; + + GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs)); + + galloc->talloc = NULL; + + // alloc hash_values if needed + if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) { + free(galloc->hash_values); + galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size); + galloc->hash_values_size = hash_size; + } + + // free hash_set.keys if needed + if 
(galloc->hash_set.keys != NULL) { + free(galloc->hash_set.keys); + } + galloc->hash_set = hash_set; + + // reset hash values + memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size); + + galloc->hash_allocs = hash_node_talloc; + + ggml_tallocr_alloc_graph_impl(galloc, graph); + + // remove unowned resources + galloc->hash_set.keys = NULL; + galloc->hash_allocs = NULL; +} + +// legacy API wrapper + +struct ggml_allocr { + ggml_tallocr_t talloc; + ggml_gallocr_t galloc; +}; + +static ggml_allocr_t ggml_allocr_new_impl(ggml_tallocr_t talloc) { + ggml_allocr_t alloc = (ggml_allocr_t)malloc(sizeof(struct ggml_allocr)); + *alloc = (struct ggml_allocr) { + /*.talloc = */ talloc, + /*.galloc = */ ggml_gallocr_new(), + }; + return alloc; +} + +ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment) { + return ggml_allocr_new_impl(ggml_tallocr_new(data, size, alignment)); +} + +ggml_allocr_t ggml_allocr_new_measure(size_t alignment) { + return ggml_allocr_new_impl(ggml_tallocr_new_measure(alignment)); +} + +ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) { + return ggml_allocr_new_impl(ggml_tallocr_new_from_buffer(buffer)); +} + +ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size) { + return ggml_allocr_new_impl(ggml_tallocr_new_from_backend(backend, size)); +} + +ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend) { + return ggml_allocr_new_impl(ggml_tallocr_new_measure_from_backend(backend)); +} + +struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc) { + return ggml_tallocr_get_buffer(alloc->talloc); +} + +void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) { + ggml_gallocr_set_parse_seq(alloc->galloc, list, n); +} + +void ggml_allocr_free(ggml_allocr_t alloc) { + ggml_gallocr_free(alloc->galloc); + ggml_tallocr_free(alloc->talloc); + free(alloc); +} + +bool ggml_allocr_is_measure(ggml_allocr_t alloc) { + return ggml_tallocr_is_measure(alloc->talloc); +} + +void ggml_allocr_reset(ggml_allocr_t alloc) { + ggml_tallocr_reset(alloc->talloc); +} + +void ggml_allocr_alloc(ggml_allocr_t alloc, struct ggml_tensor * tensor) { + ggml_tallocr_alloc(alloc->talloc, tensor); +} + +size_t ggml_allocr_max_size(ggml_allocr_t alloc) { + return ggml_tallocr_max_size(alloc->talloc); +} + +size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) { + return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph); +} + +// utils +ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_get_no_alloc(ctx) == true); + + size_t alignment = ggml_backend_buft_get_alignment(buft); + + size_t nbytes = 0; + for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->data == NULL && t->view_src == NULL) { + nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment); + } + } + + if (nbytes == 0) { + fprintf(stderr, "%s: no tensors to allocate\n", __func__); + return NULL; + } + + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes); + ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer); + + for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->data == NULL) { + if (t->view_src == NULL) { + ggml_tallocr_alloc(tallocr, t); + } else { + ggml_backend_view_init(buffer, t); + } + } + } + + 
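A minimal usage sketch for this utility: create tensors in a no_alloc context, then let ggml_backend_alloc_ctx_tensors() place all of them in a single backend buffer. ggml_backend_cpu_init(), ggml_tensor_overhead() and ggml_new_tensor_1d() come from the companion ggml/ggml-backend headers rather than this excerpt, so treat them as assumptions of the sketch:

/* usage sketch; ggml_backend_cpu_init(), ggml_tensor_overhead() and
   ggml_new_tensor_1d() are assumed from the companion headers */
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16 * ggml_tensor_overhead(), /* tensor metadata only */
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,                        /* required by the helper above */
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    (void) a; (void) b;

    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    /* ... set data with ggml_backend_tensor_set(), build and run graphs, ... */

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}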
ggml_tallocr_free(tallocr); + + return buffer; +} + +ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) { + return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend)); +} diff --git a/applications/src/llama/ggml-alloc.h b/applications/src/llama/ggml-alloc.h new file mode 100644 index 000000000..64a412468 --- /dev/null +++ b/applications/src/llama/ggml-alloc.h @@ -0,0 +1,92 @@ +#pragma once + +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_backend; +struct ggml_backend_buffer; +struct ggml_backend_buffer_type; + +// +// Legacy API +// + +typedef struct ggml_allocr * ggml_allocr_t; + +// initialize allocator for use with CPU backend only +GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment); +GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment); + +// initialize allocator for use with ggml-backend +GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer); +GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer +GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend); + +GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc); + +// tell the allocator to parse nodes following the order described in the list +// you should call this if your graph are optimized to execute out-of-order +GGML_API void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n); + +GGML_API void ggml_allocr_free (ggml_allocr_t alloc); +GGML_API bool ggml_allocr_is_measure (ggml_allocr_t alloc); +GGML_API void ggml_allocr_reset (ggml_allocr_t alloc); +GGML_API void ggml_allocr_alloc (ggml_allocr_t alloc, struct ggml_tensor * tensor); +GGML_API size_t ggml_allocr_max_size (ggml_allocr_t alloc); + +GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph); + +// +// ggml-backend v2 API +// + +// Separate tensor and graph allocator objects +// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators +// The original API is kept as a wrapper around the new API + +// Tensor allocator +typedef struct ggml_tallocr * ggml_tallocr_t; + +GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment); +GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment); +GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer); +GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer +GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend); + +GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc); + +GGML_API void ggml_tallocr_free (ggml_tallocr_t talloc); +GGML_API bool ggml_tallocr_is_measure (ggml_tallocr_t talloc); +GGML_API void ggml_tallocr_reset (ggml_tallocr_t talloc); +GGML_API void ggml_tallocr_alloc (ggml_tallocr_t talloc, struct ggml_tensor * tensor); +GGML_API size_t ggml_tallocr_max_size (ggml_tallocr_t talloc); + + +// Graph allocator +typedef struct ggml_gallocr * ggml_gallocr_t; + +GGML_API ggml_gallocr_t ggml_gallocr_new(void); +GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc); + +GGML_API void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n); +GGML_API size_t 
ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph); + +// Allocate tensors from the allocators given by the hash table +GGML_API void ggml_gallocr_alloc_graph_n( + ggml_gallocr_t galloc, + struct ggml_cgraph * graph, + struct ggml_hash_set hash_set, + ggml_tallocr_t * hash_node_talloc); + + +// Utils +// Create a buffer and allocate all the tensors in a ggml_context +GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft); +GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend); + +#ifdef __cplusplus +} +#endif diff --git a/applications/src/llama/ggml-backend-impl.h b/applications/src/llama/ggml-backend-impl.h new file mode 100644 index 000000000..f588af602 --- /dev/null +++ b/applications/src/llama/ggml-backend-impl.h @@ -0,0 +1,112 @@ +#pragma once + +// ggml-backend internal header + +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + + // + // Backend buffer + // + + // buffer type + typedef void * ggml_backend_buffer_type_context_t; + + struct ggml_backend_buffer_type_i { + ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size); + size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment + size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding + bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend + }; + + struct ggml_backend_buffer_type { + struct ggml_backend_buffer_type_i iface; + ggml_backend_buffer_type_context_t context; + }; + + // buffer + typedef void * ggml_backend_buffer_context_t; + + struct ggml_backend_buffer_i { + void (*free_buffer)(ggml_backend_buffer_t buffer); + //void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras + void * (*get_base) (ggml_backend_buffer_t buffer); + void (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) copy tensor between different buffer-type, allow for single-copy tranfers + void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst); + void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst); + }; + + struct ggml_backend_buffer { + struct ggml_backend_buffer_i iface; + ggml_backend_buffer_type_t buft; + ggml_backend_buffer_context_t context; + size_t size; + }; + + ggml_backend_buffer_t ggml_backend_buffer_init( + ggml_backend_buffer_type_t buft, + struct ggml_backend_buffer_i iface, + ggml_backend_buffer_context_t context, + size_t size); + + + // + // Backend + // + + typedef void * ggml_backend_context_t; + + struct ggml_backend_i { + const char * (*get_name)(ggml_backend_t backend); + + void (*free)(ggml_backend_t backend); + + // buffer allocation + ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend); + + // (optional) asynchroneous tensor data access + void (*set_tensor_async)(ggml_backend_t backend, 
struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + + // (optional) asynchroneous tensor copy + void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); + void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); + + void (*synchronize) (ggml_backend_t backend); + + // compute graph with a plan + ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph); + void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); + + // compute graph without a plan + void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph); + + // check if the backend supports an operation + bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); + }; + + struct ggml_backend { + struct ggml_backend_i iface; + + ggml_backend_context_t context; + }; + + + // + // Backend registry + // + + typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data); + + void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data); + +#ifdef __cplusplus +} +#endif diff --git a/applications/src/llama/ggml-backend.c b/applications/src/llama/ggml-backend.c new file mode 100644 index 000000000..3a22cd085 --- /dev/null +++ b/applications/src/llama/ggml-backend.c @@ -0,0 +1,1357 @@ +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include + + +#define MAX(a, b) ((a) > (b) ? 
(a) : (b)) + + +// backend buffer type + +ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return buft->iface.alloc_buffer(buft, size); +} + +size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) { + return buft->iface.get_alignment(buft); +} + +size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) { + // get_alloc_size is optional, defaults to ggml_nbytes + if (buft->iface.get_alloc_size) { + return buft->iface.get_alloc_size(buft, tensor); + } + return ggml_nbytes(tensor); +} + +bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return buft->iface.supports_backend(buft, backend); +} + +// backend buffer + +ggml_backend_buffer_t ggml_backend_buffer_init( + ggml_backend_buffer_type_t buft, + struct ggml_backend_buffer_i iface, + ggml_backend_buffer_context_t context, + size_t size) { + ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer)); + + GGML_ASSERT(iface.get_base != NULL); + + (*buffer) = (struct ggml_backend_buffer) { + /* .interface = */ iface, + /* .buft = */ buft, + /* .context = */ context, + /* .size = */ size, + }; + + return buffer; +} + +void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { + if (buffer == NULL) { + return; + } + + if (buffer->iface.free_buffer != NULL) { + buffer->iface.free_buffer(buffer); + } + free(buffer); +} + +size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { + return buffer->size; +} + +void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { + void * base = buffer->iface.get_base(buffer); + + GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); + + return base; +} + +void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + // init_tensor is optional + if (buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } +} + +size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) { + return ggml_backend_buft_get_alignment(ggml_backend_buffer_type(buffer)); +} + +size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type(buffer), tensor); +} + +ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) { + return buffer->buft; +} + +// backend + +const char * ggml_backend_name(ggml_backend_t backend) { + if (backend == NULL) { + return "NULL"; + } + return backend->iface.get_name(backend); +} + +void ggml_backend_free(ggml_backend_t backend) { + if (backend == NULL) { + return; + } + + backend->iface.free(backend); +} + +ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) { + return backend->iface.get_default_buffer_type(backend); +} + +ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) { + return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size); +} + +size_t ggml_backend_get_alignment(ggml_backend_t backend) { + return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend)); +} + +void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + 
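A short usage sketch for the helpers defined above: allocate a buffer through the backend's default buffer type and query its size and alignment. Only ggml_backend_cpu_init() is assumed from the companion ggml-backend.h; the remaining calls are defined in this file:

/* usage sketch; ggml_backend_cpu_init() is an assumption, the rest is defined above */
#include <stdio.h>
#include "ggml-backend.h"

int main(void) {
    ggml_backend_t backend = ggml_backend_cpu_init();

    /* buffers allocated through the default buffer type honor the backend's alignment */
    printf("backend: %s, alignment: %zu\n",
           ggml_backend_name(backend), ggml_backend_get_alignment(backend));

    ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, 1024 * 1024);
    printf("buffer size: %zu bytes\n", ggml_backend_buffer_get_size(buf));

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    return 0;
}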
backend->iface.set_tensor_async(backend, tensor, data, offset, size); +} + +void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + backend->iface.get_tensor_async(backend, tensor, data, offset, size); +} + +void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + tensor->buffer->iface.set_tensor(tensor->buffer, tensor, data, offset, size); +} + +void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + tensor->buffer->iface.get_tensor(tensor->buffer, tensor, data, offset, size); +} + +void ggml_backend_synchronize(ggml_backend_t backend) { + if (backend->iface.synchronize == NULL) { + return; + } + + backend->iface.synchronize(backend); +} + +ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + return backend->iface.graph_plan_create(backend, cgraph); +} + +void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + backend->iface.graph_plan_free(backend, plan); +} + +void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + backend->iface.graph_plan_compute(backend, plan); + + // TODO: optional sync + ggml_backend_synchronize(backend); +} + +void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + backend->iface.graph_compute(backend, cgraph); + + // TODO: optional sync + ggml_backend_synchronize(backend); +} + +bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + return backend->iface.supports_op(backend, op); +} + +// backend copy + +static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { + if (a->type != b->type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (a->ne[i] != b->ne[i]) { + return false; + } + if (a->nb[i] != b->nb[i]) { + return false; + } + } + return true; +} + +void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { + //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]); + //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]); + GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); + + // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src)); + + if (src == dst) { + return; + } + + // TODO: allow backends to support copy to/from same backend + + if (dst->buffer->iface.cpy_tensor_from != NULL) { 
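To tie the graph-compute wrappers above to the allocation utilities from ggml-alloc.c, here is a hedged end-to-end sketch: build a one-op graph, place the tensors in a CPU backend buffer, run ggml_backend_graph_compute(), and read the result back. ggml_backend_cpu_init(), ggml_new_graph(), ggml_build_forward_expand() and ggml_add() are taken from the companion ggml headers and are assumptions of this sketch, not part of this excerpt:

/* end-to-end sketch: one-op graph computed on the CPU backend */
#include <stdio.h>
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, /* tensor data lives in a backend buffer, not in the context */
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);

    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    const float va[4] = {1, 2, 3, 4};
    const float vb[4] = {10, 20, 30, 40};
    ggml_backend_tensor_set(a, va, 0, sizeof(va));
    ggml_backend_tensor_set(b, vb, 0, sizeof(vb));

    ggml_backend_graph_compute(backend, gf);

    float vc[4];
    ggml_backend_tensor_get(c, vc, 0, sizeof(vc));
    printf("%g %g %g %g\n", vc[0], vc[1], vc[2], vc[3]); /* expected: 11 22 33 44 */

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}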
+ dst->buffer->iface.cpy_tensor_from(dst->buffer, src, dst); + } else if (src->buffer->iface.cpy_tensor_to != NULL) { + src->buffer->iface.cpy_tensor_to(src->buffer, src, dst); + } else { + // shouldn't be hit when copying from/to CPU + #ifndef NDEBUG + fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to " + "are implemented for %s and %s, falling back to get/set\n", src->name, dst->name); + #endif + size_t nbytes = ggml_nbytes(src); + void * data = malloc(nbytes); + ggml_backend_tensor_get(src, data, 0, nbytes); + ggml_backend_tensor_set(dst, data, 0, nbytes); + free(data); + } +} + +// backend registry + +#define GGML_MAX_BACKENDS_REG 16 + +struct ggml_backend_reg { + char name[128]; + ggml_backend_init_fn init_fn; + ggml_backend_buffer_type_t default_buffer_type; + void * user_data; +}; + +static struct ggml_backend_reg ggml_backend_registry[GGML_MAX_BACKENDS_REG]; +static size_t ggml_backend_registry_count = 0; + +static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data); + +static void ggml_backend_registry_init(void) { + static bool initialized = false; + + if (initialized) { + return; + } + + initialized = true; + + ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL); + + // add forward decls here to avoid including the backend headers +#ifdef GGML_USE_CUBLAS + extern void ggml_backend_cuda_reg_devices(void); + ggml_backend_cuda_reg_devices(); +#endif + +#ifdef GGML_USE_METAL + extern ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); + extern ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); + ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL); +#endif +} + +void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) { + GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG); + + int id = ggml_backend_registry_count; + + ggml_backend_registry[id] = (struct ggml_backend_reg) { + /* .name = */ {0}, + /* .fn = */ init_fn, + /* .default_buffer_type = */ default_buffer_type, + /* .user_data = */ user_data, + }; + + snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name); + +#ifndef NDEBUG + fprintf(stderr, "%s: registered backend %s\n", __func__, name); +#endif + + ggml_backend_registry_count++; +} + +size_t ggml_backend_reg_get_count(void) { + ggml_backend_registry_init(); + + return ggml_backend_registry_count; +} + +size_t ggml_backend_reg_find_by_name(const char * name) { + ggml_backend_registry_init(); + + for (size_t i = 0; i < ggml_backend_registry_count; i++) { + // TODO: case insensitive in a portable way + if (strcmp(ggml_backend_registry[i].name, name) == 0) { + return i; + } + } + return SIZE_MAX; +} + +// init from backend:params string +ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) { + ggml_backend_registry_init(); + + const char * params = strchr(backend_str, ':'); + char backend_name[128]; + if (params == NULL) { + strcpy(backend_name, backend_str); + params = ""; + } else { + strncpy(backend_name, backend_str, params - backend_str); + backend_name[params - backend_str] = '\0'; + params++; + } + + size_t backend_i = ggml_backend_reg_find_by_name(backend_name); + if (backend_i == SIZE_MAX) { + fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name); + return NULL; + } + + return 
ggml_backend_reg_init_backend(backend_i, params); +} + +const char * ggml_backend_reg_get_name(size_t i) { + ggml_backend_registry_init(); + + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_registry[i].name; +} + +ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) { + ggml_backend_registry_init(); + + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data); +} + +ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) { + ggml_backend_registry_init(); + + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_registry[i].default_buffer_type; +} + +ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) { + ggml_backend_registry_init(); + + GGML_ASSERT(i < ggml_backend_registry_count); + return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size); +} + +// backend CPU + +static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { + return (void *)buffer->context; +} + +static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + free(buffer->context); + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy((char *)tensor->data + offset, data, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy(data, (const char *)tensor->data + offset, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src)); + + GGML_UNUSED(buffer); +} + +static struct ggml_backend_buffer_i cpu_backend_buffer_i = { + /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from, + /* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to, +}; + +// for buffers from ptr, free is not called +static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { + /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from, + /* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to, +}; + 
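The buffer interfaces above are easiest to follow from the caller's side. The snippet below is a rough, editor-added sketch (not part of the patch itself): it initializes the CPU backend through the registry string API, wraps caller-owned memory with ggml_backend_cpu_buffer_from_ptr (which uses the cpu_backend_buffer_i_from_ptr interface defined above, so the pointer is never freed by ggml), and round-trips tensor data through ggml_backend_tensor_set/get. All functions are declared in ggml.h and ggml-backend.h from this directory; the buffer size and tensor shape are arbitrary example values.

#include <stdio.h>
#include <stdlib.h>
#include "ggml.h"
#include "ggml-backend.h"

int main(void) {
    // init the CPU backend by name through the registry ("name[:params]" syntax)
    ggml_backend_t backend = ggml_backend_reg_init_backend_from_str("CPU");
    if (backend == NULL) {
        return 1;
    }
    printf("backend: %s\n", ggml_backend_name(backend));

    // a no_alloc context holds only tensor metadata; the data will live in the wrapped buffer
    struct ggml_init_params params = {
        /* .mem_size   = */ ggml_tensor_overhead() * 8,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor  * t   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);

    // wrap caller-owned memory; free_buffer is NULL for this buffer type, so `mem` remains ours to free
    void * mem = malloc(1024);
    ggml_backend_buffer_t buf = ggml_backend_cpu_buffer_from_ptr(mem, 1024);
    ggml_backend_tensor_alloc(buf, t, ggml_backend_buffer_get_base(buf));

    // these dispatch through buffer->iface.set_tensor / get_tensor (a plain memcpy on CPU)
    float src[16], dst[16];
    for (int i = 0; i < 16; i++) { src[i] = (float) i; }
    ggml_backend_tensor_set(t, src, 0, sizeof(src));
    ggml_backend_tensor_get(t, dst, 0, sizeof(dst));
    printf("dst[3] = %.1f\n", dst[3]);

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    free(mem);
    ggml_backend_free(backend);
    return 0;
}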
+static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512 + +static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned + void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC? + + GGML_ASSERT(data != NULL && "failed to allocate buffer"); + + return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size); +} + +static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return ggml_backend_is_cpu(backend); + + GGML_UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_cpu = { + /* .iface = */ { + /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend, + }, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_cpu; +} + +struct ggml_backend_cpu_context { + int n_threads; + void * work_data; + size_t work_size; +}; + +static const char * ggml_backend_cpu_name(ggml_backend_t backend) { + return "CPU"; + + GGML_UNUSED(backend); +} + +static void ggml_backend_cpu_free(ggml_backend_t backend) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + free(cpu_ctx->work_data); + free(cpu_ctx); + free(backend); +} + +static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) { + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(backend); +} + +struct ggml_backend_plan_cpu { + struct ggml_cplan cplan; + struct ggml_cgraph cgraph; +}; + +static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); + + cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + cpu_plan->cgraph = *cgraph; + + if (cpu_plan->cplan.work_size > 0) { + cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); + } + + return cpu_plan; +} + +static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + free(cpu_plan->cplan.work_data); + free(cpu_plan); + + GGML_UNUSED(backend); +} + +static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); + + GGML_UNUSED(backend); +} + +static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + + if (cpu_ctx->work_size < cplan.work_size) { + // TODO: may be faster to free and use malloc to avoid the copy + cpu_ctx->work_data = 
realloc(cpu_ctx->work_data, cplan.work_size); + cpu_ctx->work_size = cplan.work_size; + } + + cplan.work_data = cpu_ctx->work_data; + + ggml_graph_compute(cgraph, &cplan); +} + +static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + return true; + + GGML_UNUSED(backend); + GGML_UNUSED(op); +} + +static struct ggml_backend_i cpu_backend_i = { + /* .get_name = */ ggml_backend_cpu_name, + /* .free = */ ggml_backend_cpu_free, + /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_from_async = */ NULL, + /* .cpy_tensor_to_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, + /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cpu_graph_compute, + /* .supports_op = */ ggml_backend_cpu_supports_op, +}; + +ggml_backend_t ggml_backend_cpu_init(void) { + struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); + + ctx->n_threads = GGML_DEFAULT_N_THREADS; + ctx->work_data = NULL; + ctx->work_size = 0; + + ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); + + *cpu_backend = (struct ggml_backend) { + /* .interface = */ cpu_backend_i, + /* .context = */ ctx + }; + return cpu_backend; +} + +bool ggml_backend_is_cpu(ggml_backend_t backend) { + return backend->iface.get_name == ggml_backend_cpu_name; +} + +void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->n_threads = n_threads; +} + +ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { + return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size); +} + +static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) { + return ggml_backend_cpu_init(); + + GGML_UNUSED(params); + GGML_UNUSED(user_data); +} + + +// scheduler + +#define GGML_MAX_BACKENDS 4 +#define GGML_MAX_SPLITS 256 +#define GGML_MAX_SPLIT_INPUTS 16 + +struct ggml_backend_sched_split { + ggml_tallocr_t tallocr; + int i_start; + int i_end; + struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS]; + int n_inputs; + struct ggml_cgraph graph; +}; + +struct ggml_backend_sched { + int n_backends; + ggml_backend_t backends[GGML_MAX_BACKENDS]; + ggml_tallocr_t tallocs[GGML_MAX_BACKENDS]; + + ggml_gallocr_t galloc; + + struct ggml_hash_set hash_set; + ggml_tallocr_t * node_talloc; // [hash_set.size] + struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS] + + struct ggml_cgraph * graph; + struct ggml_backend_sched_split splits[GGML_MAX_SPLITS]; + int n_splits; + + struct ggml_context * ctx; + + // align context_buffer to GGML_MEM_ALIGN + #ifdef _MSC_VER + __declspec(align(GGML_MEM_ALIGN)) + #else + __attribute__((aligned(GGML_MEM_ALIGN))) + #endif + char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)]; +}; + +#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node) +#define node_allocr(node) sched->node_talloc[hash_id(node)] + +static bool ggml_is_view_op(enum ggml_op op) { + return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == 
GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE; +} + +// returns the priority of the backend, lower is better +static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) { + for (int i = 0; i < sched->n_backends; i++) { + if (sched->backends[i] == backend) { + return i; + } + } + return INT_MAX; +} + +static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) { + for (int i = 0; i < sched->n_backends; i++) { + if (sched->tallocs[i] == allocr) { + return i; + } + } + return INT_MAX; +} + +static ggml_backend_t get_buffer_backend(ggml_backend_sched_t sched, ggml_backend_buffer_t buffer) { + if (buffer == NULL) { + return NULL; + } + // find highest prio backend that supports the buffer type + for (int i = 0; i < sched->n_backends; i++) { + if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) { + return sched->backends[i]; + } + } + GGML_ASSERT(false && "tensor buffer type not supported by any backend"); +} + +static ggml_backend_t get_allocr_backend(ggml_backend_sched_t sched, ggml_tallocr_t allocr) { + if (allocr == NULL) { + return NULL; + } + // find highest prio backend that supports the buffer type + for (int i = 0; i < sched->n_backends; i++) { + if (sched->tallocs[i] == allocr) { + return sched->backends[i]; + } + } + GGML_UNREACHABLE(); +} + +#if 0 +static char causes[GGML_DEFAULT_GRAPH_SIZE*8 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove +#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__) +#define GET_CAUSE(node) causes[hash_id(node)] +#else +#define SET_CAUSE(node, ...) +#define GET_CAUSE(node) "" +#endif + +// returns the backend that should be used for the node based on the current locations +static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) { + // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there + // ie. kv cache updates + // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend. 
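+ // preference order: (1) the buffer the node is already allocated in, (2) the buffer of its view_src, (3) the buffer of one of its sources, favoring higher-priority backends and larger sources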
+ // dst + ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer); + if (cur_backend != NULL) { + SET_CAUSE(node, "1.dst"); + return cur_backend; + } + + // view_src + if (node->view_src != NULL && get_buffer_backend(sched, node->view_src->buffer) != NULL) { + SET_CAUSE(node, "1.vsrc"); + return get_buffer_backend(sched, node->view_src->buffer); + } + + // src + int cur_prio = INT_MAX; + size_t cur_size = 0; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + const struct ggml_tensor * src = node->src[i]; + if (src == NULL) { + break; + } + ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer); + if (src_backend != NULL) { + int src_prio = sched_backend_prio(sched, src_backend); + size_t src_size = ggml_nbytes(src); + if (src_prio < cur_prio && src_size >= cur_size) { + cur_prio = src_prio; + cur_size = src_size; + cur_backend = src_backend; + SET_CAUSE(node, "1.src%d", i); + } + } + } + return cur_backend; +} + +static char * fmt_size(size_t size) { + static char buffer[128]; + if (size >= 1024*1024) { + sprintf(buffer, "%zuM", size/1024/1024); + } else { + sprintf(buffer, "%zuK", size/1024); + } + return buffer; +} + +static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + int cur_split = 0; + for (int i = 0; i < graph->n_nodes; i++) { + if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) { + ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr); + fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), + sched->splits[cur_split].n_inputs); + for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) { + fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, + fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j]))); + } + fprintf(stderr, "\n"); + cur_split++; + } + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + ggml_tallocr_t node_allocr = node_allocr(node); + ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME: + fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, + fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + ggml_tallocr_t src_allocr = node_allocr(src); + ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL; + fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, + fmt_size(ggml_nbytes(src)), src_backend ? 
ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src)); + } + fprintf(stderr, "\n"); + } +} + +// creates a copy of the tensor with the same memory layout +static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) { + struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor); + for (int i = 0; i < GGML_MAX_DIMS; i++) { + dup->nb[i] = tensor->nb[i]; + } + return dup; +} + +// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend +// TODO: merge passes +static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + // reset state + size_t hash_size = sched->hash_set.size; + memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); + memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size); + memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size); + sched->n_splits = 0; + + struct ggml_init_params params = { + /* .mem_size = */ sizeof(sched->context_buffer), + /* .mem_buffer = */ sched->context_buffer, + /* .no_alloc = */ true + }; + + if (sched->ctx != NULL) { + ggml_free(sched->ctx); + } + + sched->ctx = ggml_init(params); + + // pass 1: assign backends to ops with allocated inputs + for (int i = 0; i < graph->n_leafs; i++) { + struct ggml_tensor * leaf = graph->leafs[i]; + if (node_allocr(leaf) != NULL) { + // do not overwrite user assignments + continue; + } + ggml_backend_t leaf_backend = get_buffer_backend(sched, leaf->buffer); + if (leaf_backend == NULL && leaf->view_src != NULL) { + leaf_backend = get_buffer_backend(sched, leaf->view_src->buffer); + } + if (leaf_backend != NULL) { + node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend); + } + } + + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (node_allocr(node) != NULL) { + // do not overwrite user assignments + continue; + } + ggml_backend_t node_backend = sched_backend_from_cur(sched, node); + if (node_backend != NULL) { + node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend); + } + } + //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); + + // pass 2: assign backends to ops from current assignments + // TODO: + // - reuse sched_backend_from_cur + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + ggml_tallocr_t node_allocr = node_allocr(node); + if (node_allocr == NULL) { + int cur_prio = INT_MAX; + size_t cur_size = 0; + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + ggml_tallocr_t src_allocr = node_allocr(src); + if (src_allocr != NULL) { + int src_prio = sched_allocr_prio(sched, src_allocr); + size_t src_size = ggml_nbytes(src); + if (src_prio < cur_prio && src_size >= cur_size) { + cur_prio = src_prio; + cur_size = src_size; + node_allocr = src_allocr; + SET_CAUSE(node, "2.src%d", j); + } + } + } + if (node_allocr != NULL) { + node_allocr(node) = node_allocr; + } + } + } + //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); + + // pass 3: assign backends to remaining src from dst (should only be leafs) + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + ggml_tallocr_t node_allocr = node_allocr(node); + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + ggml_tallocr_t src_allocr = node_allocr(src); + if 
(src_allocr == NULL) { + node_allocr(src) = node_allocr; + } + } + } + //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); + + // pass 4: split graph, find tensors that need to be copied + // TODO: + // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost + // find first backend + int cur_split = 0; + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (node->view_src == NULL) { + sched->splits[0].tallocr = node_allocr(node); + break; + } + } + sched->splits[0].i_start = 0; + sched->splits[0].n_inputs = 0; + memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK + ggml_tallocr_t cur_allocr = sched->splits[0].tallocr; + size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr); + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + + if (ggml_is_view_op(node->op)) { + continue; + } + + ggml_tallocr_t node_allocr = node_allocr(node); + + if (node_allocr != cur_allocr) { + sched->splits[cur_split].i_end = i; + cur_split++; + GGML_ASSERT(cur_split < GGML_MAX_SPLITS); + sched->splits[cur_split].tallocr = node_allocr; + sched->splits[cur_split].i_start = i; + sched->splits[cur_split].n_inputs = 0; + memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK + cur_allocr = node_allocr; + cur_backend_id = sched_allocr_prio(sched, cur_allocr); + } + + // find inputs that are not on the same backend + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + ggml_tallocr_t src_allocr = node_allocr(src); + if (src_allocr != node_allocr) { + int n_inputs = sched->splits[cur_split].n_inputs++; + GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS); + sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src; + + // create copies + size_t id = hash_id(src); + if (sched->node_copies[id][cur_backend_id] == NULL) { + struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); + sched->node_copies[id][cur_backend_id] = tensor_copy; + node_allocr(tensor_copy) = cur_allocr; + ggml_backend_t backend = get_allocr_backend(sched, cur_allocr); + ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name); + } + node->src[j] = sched->node_copies[id][cur_backend_id]; + } + } + } + sched->splits[cur_split].i_end = graph->n_nodes; + sched->n_splits = cur_split + 1; + + //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout); + +#if 1 + // sanity check: all sources should have the same backend as the node + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + ggml_tallocr_t node_allocr = node_allocr(node); + if (node_allocr == NULL) { + fprintf(stderr, "!!!!!!! %s has no backend\n", node->name); + } + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + ggml_tallocr_t src_allocr = node_allocr(src); + if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now + fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n", + node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL", + j, src->name, src_allocr ? 
ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL"); + } + } + } +#endif + + // create copies of the graph for each split + // FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way + struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false); + for (int i = 0; i < sched->n_splits; i++) { + struct ggml_backend_sched_split * split = &sched->splits[i]; + split->graph = ggml_graph_view(graph, split->i_start, split->i_end); + + // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split + for (int j = 0; j < split->n_inputs; j++) { + struct ggml_tensor * input = split->inputs[j]; + struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)]; + input_cpy->src[0] = input; + graph_copy->nodes[graph_copy->n_nodes++] = input_cpy; + } + + for (int j = split->i_start; j < split->i_end; j++) { + graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j]; + } + } + sched->graph = graph_copy; +} + +static void sched_alloc_splits(ggml_backend_sched_t sched) { + ggml_gallocr_alloc_graph_n( + sched->galloc, + sched->graph, + sched->hash_set, + sched->node_talloc); +} + +static void sched_compute_splits(ggml_backend_sched_t sched) { + uint64_t copy_us[GGML_MAX_BACKENDS] = {0}; + uint64_t compute_us[GGML_MAX_BACKENDS] = {0}; + + struct ggml_backend_sched_split * splits = sched->splits; + + for (int i = 0; i < sched->n_splits; i++) { + struct ggml_backend_sched_split * split = &splits[i]; + ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr); + int split_backend_id = sched_backend_prio(sched, split_backend); + + // copy the input tensors to the split backend + uint64_t copy_start_us = ggml_time_us(); + for (int j = 0; j < split->n_inputs; j++) { + struct ggml_tensor * input = split->inputs[j]; + struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)]; + if (input->buffer == NULL) { + if (input->view_src == NULL) { + fprintf(stderr, "input %s has no buffer and no view_src\n", input->name); + exit(1); + } + // FIXME: may need to use the sched buffer instead + ggml_backend_view_init(input->view_src->buffer, input); + } + if (input_cpy->buffer == NULL) { + fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name); + exit(1); + } + //GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend); + //GGML_ASSERT(input_cpy->buffer->backend == split_backend); + ggml_backend_tensor_copy(input, input_cpy); + } + // ggml_backend_synchronize(split_backend); + int64_t copy_end_us = ggml_time_us(); + copy_us[split_backend_id] += copy_end_us - copy_start_us; + +#if 0 + char split_filename[GGML_MAX_NAME]; + snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend)); + ggml_graph_dump_dot(split->graph, NULL, split_filename); +#endif + + uint64_t compute_start_us = ggml_time_us(); + ggml_backend_graph_compute(split_backend, &split->graph); + // ggml_backend_synchronize(split_backend); + uint64_t compute_end_us = ggml_time_us(); + compute_us[split_backend_id] += compute_end_us - compute_start_us; + } + +#if 0 + // per-backend timings + fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits); + for (int i = 0; i < sched->n_backends; i++) { + if (copy_us[i] > 0 || compute_us[i] > 0) { + fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", 
ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]); + } + } +#endif +} + +static void sched_reset(ggml_backend_sched_t sched) { + for (int i = 0; i < sched->n_backends; i++) { + ggml_tallocr_reset(sched->tallocs[i]); + } +} + +ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) { + GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS); + + struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched)); + memset(sched, 0, sizeof(struct ggml_backend_sched)); + + sched->n_backends = n_backends; + for (int i = 0; i < n_backends; i++) { + sched->backends[i] = backends[i]; + } + + sched->galloc = ggml_gallocr_new(); + + // init measure allocs for each backend + for (int i = 0; i < n_backends; i++) { + sched->tallocs[i] = ggml_tallocr_new_measure_from_backend(backends[i]); + } + + return sched; +} + +void ggml_backend_sched_free(ggml_backend_sched_t sched) { + if (sched == NULL) { + return; + } + for (int i = 0; i < sched->n_backends; i++) { + ggml_tallocr_free(sched->tallocs[i]); + } + ggml_gallocr_free(sched->galloc); + free(sched->hash_set.keys); + free(sched->node_talloc); + free(sched->node_copies); + free(sched); +} + +void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { + // initialize hash tables + size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS; + sched->hash_set.size = hash_size; + sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size); + sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size); + sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size); + + sched_split_graph(sched, measure_graph); + sched_alloc_splits(sched); + + // allocate buffers and reset allocators + for (int i = 0; i < sched->n_backends; i++) { + size_t size = ggml_tallocr_max_size(sched->tallocs[i]); + ggml_tallocr_free(sched->tallocs[i]); + sched->tallocs[i] = ggml_tallocr_new_from_backend(sched->backends[i], size); + } + + sched_reset(sched); +} + +void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS); + + sched_split_graph(sched, graph); + sched_alloc_splits(sched); + sched_compute_splits(sched); + sched_reset(sched); +} + +ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) { + int backend_index = sched_backend_prio(sched, backend); + return sched->tallocs[backend_index]; +} + +ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) { + int backend_index = sched_backend_prio(sched, backend); + return ggml_tallocr_get_buffer(sched->tallocs[backend_index]); +} + +void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { + int backend_index = sched_backend_prio(sched, backend); + GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + node_allocr(node) = sched->tallocs[backend_index]; +} + +// utils +void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->buffer == NULL); + GGML_ASSERT(tensor->data == NULL); + GGML_ASSERT(tensor->view_src != NULL); + GGML_ASSERT(tensor->view_src->buffer != NULL); + GGML_ASSERT(tensor->view_src->data != NULL); + + tensor->buffer = buffer; + tensor->data = (char *)tensor->view_src->data + 
tensor->view_offs; + tensor->backend = tensor->view_src->backend; + ggml_backend_buffer_init_tensor(buffer, tensor); +} + +void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { + GGML_ASSERT(tensor->buffer == NULL); + GGML_ASSERT(tensor->data == NULL); + GGML_ASSERT(tensor->view_src == NULL); + GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer)); + GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= + (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); + + tensor->buffer = buffer; + tensor->data = addr; + ggml_backend_buffer_init_tensor(buffer, tensor); +} + +static struct ggml_tensor * graph_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, + struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) { + + GGML_ASSERT(src != NULL); + GGML_ASSERT(src->data && "graph must be allocated"); + + size_t id = ggml_hash_insert(hash_set, src); + if (id == GGML_HASHTABLE_ALREADY_EXISTS) { + return node_copies[ggml_hash_find(hash_set, src)]; + } + + struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src); + if (src->view_src != NULL) { + dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src); + dst->view_offs = src->view_offs; + } + dst->op = src->op; + memcpy(dst->op_params, src->op_params, sizeof(dst->op_params)); + ggml_set_name(dst, src->name); + + // copy src + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * s = src->src[i]; + if (s == NULL) { + break; + } + dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s); + } + + node_copies[id] = dst; + return dst; +} + +static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) { + size_t id = ggml_hash_find(hash_set, src); + if (node_init[id]) { + return; + } + node_init[id] = true; + + struct ggml_tensor * dst = node_copies[id]; + if (dst->view_src != NULL) { + ggml_backend_view_init(dst->view_src->buffer, dst); + } + else { + ggml_backend_tensor_copy(src, dst); + } + + // init src + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * s = src->src[i]; + if (s == NULL) { + break; + } + graph_init_tensor(hash_set, node_copies, node_init, s); + } +} + +struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { + struct ggml_hash_set hash_set = { + /* .size = */ graph->visited_hash_table.size, + /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1) + }; + struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1); + bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1); + + struct ggml_init_params params = { + /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false), + /* .mem_buffer = */ NULL, + /* .no_alloc = */ true + }; + + struct ggml_context * ctx_allocated = ggml_init(params); + struct ggml_context * ctx_unallocated = ggml_init(params); + + // dup nodes + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node); + } + + // allocate nodes + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend); + + 
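+ // at this point the copies of tensors that own their data live in ctx_allocated and have just received backing memory in `buffer`;
+ // views and unallocated tensors live in ctx_unallocated and are initialized from their view_src (or copied) below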
//printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024); + + // copy data and init views + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + graph_init_tensor(hash_set, node_copies, node_init, node); + } + + // build graph copy + struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false); + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)]; + graph_copy->nodes[i] = node_copy; + } + graph_copy->n_nodes = graph->n_nodes; + + free(hash_set.keys); + free(node_copies); + free(node_init); + + return (struct ggml_backend_graph_copy) { + /* .buffer = */ buffer, + /* .ctx_allocated = */ ctx_allocated, + /* .ctx_unallocated = */ ctx_unallocated, + /* .graph = */ graph_copy, + }; +} + +void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) { + ggml_backend_buffer_free(copy.buffer); + ggml_free(copy.ctx_allocated); + ggml_free(copy.ctx_unallocated); +} + +void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) { + struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); + struct ggml_cgraph * g1 = graph; + struct ggml_cgraph * g2 = copy.graph; + + assert(g1->n_nodes == g2->n_nodes); + + for (int i = 0; i < g1->n_nodes; i++) { + //printf("eval %d/%d\n", i, g1->n_nodes); + struct ggml_tensor * t1 = g1->nodes[i]; + struct ggml_tensor * t2 = g2->nodes[i]; + + assert(t1->op == t2->op && ggml_are_same_layout(t1, t2)); + + struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); + struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); + + ggml_backend_graph_compute(backend1, &g1v); + ggml_backend_graph_compute(backend2, &g2v); + + if (ggml_is_view_op(t1->op)) { + continue; + } + + // compare results, calculate rms etc + if (!callback(i, t1, t2, user_data)) { + break; + } + } + + ggml_backend_graph_copy_free(copy); +} diff --git a/applications/src/llama/ggml-backend.h b/applications/src/llama/ggml-backend.h new file mode 100644 index 000000000..58d5ccae6 --- /dev/null +++ b/applications/src/llama/ggml-backend.h @@ -0,0 +1,181 @@ +#pragma once + +#include "ggml.h" +#include "ggml-alloc.h" + +#ifdef __cplusplus +extern "C" { +#endif + + typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t; + typedef struct ggml_backend_buffer * ggml_backend_buffer_t; + typedef struct ggml_backend * ggml_backend_t; + typedef void * ggml_backend_graph_plan_t; + + // + // Backend buffer + // + + // buffer type + GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); + GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); + GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); + GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend); + + // buffer + GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); + GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API size_t ggml_backend_buffer_get_alignment 
(ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer); + + // + // Backend + // + + + GGML_API const char * ggml_backend_name(ggml_backend_t backend); + GGML_API void ggml_backend_free(ggml_backend_t backend); + + GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend); + GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); + GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); + + GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + + GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + + GGML_API void ggml_backend_synchronize(ggml_backend_t backend); + + GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph); + + GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan); + GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op); + + // tensor copy between different backends + GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy + + // + // CPU backend + // + + GGML_API ggml_backend_t ggml_backend_cpu_init(void); + + GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend); + GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); + + // Create a backend buffer from an existing pointer + GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); + + GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void); + + // + // Backend registry + // + + // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way + + GGML_API size_t ggml_backend_reg_get_count(void); + GGML_API size_t ggml_backend_reg_find_by_name(const char * name); + GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params] + GGML_API const char * ggml_backend_reg_get_name(size_t i); + GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific + GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i); + GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size); + + // + // Backend scheduler + // + + // The backend scheduler allows for multiple backends to be used together + // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends + // The backends 
are selected based on: + // - the backend that supports the operation + // - the location of the pre-allocated tensors (e.g. the weights) + /* + Example usage: + + sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends); + // sched is initialized with measure allocators and cannot be used until allocated with a measure graph + + // initialize buffers from a measure graph + measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed + + // in build_graph: + build_graph(...) { + // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer) + alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu); + ggml_allocr_alloc(alloc_cpu, tensor); + + // manually assigning nodes to a backend (optional, shouldn't be needed in most cases) + struct ggml_tensor * node = ggml_mul_mat(ctx, ...); + ggml_backend_sched_set_node_backend(sched, node, backend_gpu); + } + + // allocate backend buffers from measure graph + ggml_backend_sched_init_measure(sched, measure_graph); + + // the scheduler is now ready to compute graphs + + // compute + graph = build_graph(sched); + ggml_backend_sched_graph_compute(sched, graph); + */ + + struct ggml_backend_sched; + typedef struct ggml_backend_sched * ggml_backend_sched_t; + + // Initialize a backend scheduler + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends); + + GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); + + // Initialize backend buffers from a measure graph + GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); + + GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend); + GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend); + + GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); + + // Allocate a graph on the backend scheduler + GGML_API void ggml_backend_sched_graph_compute( + ggml_backend_sched_t sched, + struct ggml_cgraph * graph); + + + // + // Utils + // + + struct ggml_backend_graph_copy { + ggml_backend_buffer_t buffer; + struct ggml_context * ctx_allocated; + struct ggml_context * ctx_unallocated; + struct ggml_cgraph * graph; + }; + + // Copy a graph to a different backend + GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph); + GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy); + + typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); + + // Compare the output of two backends + GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data); + + // Tensor initialization + GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); + GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + + +#ifdef __cplusplus +} +#endif diff --git a/applications/src/llama/ggml-cuda.cu b/applications/src/llama/ggml-cuda.cu new file mode 100644 index 000000000..0a63c1ecf --- /dev/null +++ b/applications/src/llama/ggml-cuda.cu @@ -0,0 +1,9874 @@ +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include + + +#if defined(GGML_USE_HIPBLAS) +#include +#include +#include +#ifdef __HIP_PLATFORM_AMD__ +// for rocblas_initialize() +#include "rocblas/rocblas.h" +#endif // __HIP_PLATFORM_AMD__ +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_32F HIPBLAS_R_32F +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define cublasCreate hipblasCreate +#define cublasGemmEx hipblasGemmEx +#define cublasGemmBatchedEx hipblasGemmBatchedEx +#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEvent_t hipEvent_t +#define cudaEventDestroy hipEventDestroy +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMemcpy hipMemcpy +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaMemsetAsync hipMemsetAsync +#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamFireAndForget hipStreamFireAndForget +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#else +#include +#include +#include +#endif // defined(GGML_USE_HIPBLAS) + +#include "ggml-cuda.h" +#include "ggml.h" +#include "ggml-backend-impl.h" + +#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define CC_VOLTA 700 +#define CC_OFFSET_AMD 1000000 +#define CC_RDNA2 (CC_OFFSET_AMD + 1030) + +#define GGML_CUDA_MAX_NODES 8192 + +// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication +// on modern 
hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant +// for large computational tasks. the drawback is that this requires some extra amount of VRAM: +// - 7B quantum model: +100-200 MB +// - 13B quantum model: +200-400 MB +// +//#define GGML_CUDA_FORCE_MMQ + +// TODO: improve this to be correct for more hardware +// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores +// probably other such cases, and not sure what happens on AMD hardware +#if !defined(GGML_CUDA_FORCE_MMQ) +#define CUDA_USE_TENSOR_CORES +#endif + +// max batch size to use MMQ kernels when tensor cores are available +#define MMQ_MAX_BATCH_SIZE 32 + +#if defined(GGML_USE_HIPBLAS) +#define __CUDA_ARCH__ 1300 + +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ + defined(__gfx1150__) || defined(__gfx1151__) +#define RDNA3 +#endif + +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) +#define RDNA2 +#endif + +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); +static __device__ __forceinline__ int __vsubss4(const int a, const int b) { + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); +#if __has_builtin(__builtin_elementwise_sub_sat) + const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); + return reinterpret_cast(c); +#else + int8x4_t c; + int16_t tmp; +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp = va[i] - vb[i]; + if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); + if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); + c[i] = tmp; + } + return reinterpret_cast(c); +#endif // __has_builtin(__builtin_elementwise_sub_sat) +} + +static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { +#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) + c = __builtin_amdgcn_sdot4(a, b, c, false); +#elif defined(__gfx1100__) + c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); +#elif defined(__gfx1010__) || defined(__gfx900__) + int tmp1; + int tmp2; + asm("\n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + " + : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) + : "v"(a), "v"(b) + ); +#else + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); + c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; +#endif + return c; +} +#endif // defined(GGML_USE_HIPBLAS) + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); + +#define CUDA_CHECK(err) \ + do { \ + cudaError_t err_ = (err); \ + if (err_ != cudaSuccess) { \ + int id; \ + cudaGetDevice(&id); \ + fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, 
__LINE__, \ + cudaGetErrorString(err_)); \ + fprintf(stderr, "current device: %d\n", id); \ + GGML_ASSERT(!"CUDA error"); \ + } \ + } while (0) + +#if CUDART_VERSION >= 12000 +#define CUBLAS_CHECK(err) \ + do { \ + cublasStatus_t err_ = (err); \ + if (err_ != CUBLAS_STATUS_SUCCESS) { \ + int id; \ + cudaGetDevice(&id); \ + fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ + err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ + fprintf(stderr, "current device: %d\n", id); \ + GGML_ASSERT(!"cuBLAS error"); \ + } \ + } while (0) +#else +#define CUBLAS_CHECK(err) \ + do { \ + cublasStatus_t err_ = (err); \ + if (err_ != CUBLAS_STATUS_SUCCESS) { \ + int id; \ + cudaGetDevice(&id); \ + fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ + fprintf(stderr, "current device: %d\n", id); \ + GGML_ASSERT(!"cuBLAS error"); \ + } \ + } while (0) +#endif // CUDART_VERSION >= 11 + +#if CUDART_VERSION >= 11100 +#define GGML_CUDA_ASSUME(x) __builtin_assume(x) +#else +#define GGML_CUDA_ASSUME(x) +#endif // CUDART_VERSION >= 11100 + +#ifdef GGML_CUDA_F16 +typedef half dfloat; // dequantize float +typedef half2 dfloat2; +#else +typedef float dfloat; // dequantize float +typedef float2 dfloat2; +#endif //GGML_CUDA_F16 + +static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { + const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { + const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { + return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { + return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +template +using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream); +typedef to_t_cuda_t to_fp32_cuda_t; +typedef to_t_cuda_t to_fp16_cuda_t; + +typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); +typedef void (*cpy_kernel_t)(const char * cx, char * cdst); +typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); +typedef void (*ggml_cuda_op_mul_mat_t)( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream); +typedef void (*ggml_cuda_op_flatten_t)( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream); + +// QK = number of values after dequantization +// QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization + +#define QK4_0 32 
+#define QR4_0 2 +#define QI4_0 (QK4_0 / (4 * QR4_0)) +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +#define QR4_1 2 +#define QI4_1 (QK4_1 / (4 * QR4_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +#define QR5_0 2 +#define QI5_0 (QK5_0 / (4 * QR5_0)) +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +#define QR5_1 2 +#define QI5_1 (QK5_1 / (4 * QR5_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +#define QR8_0 1 +#define QI8_0 (QK8_0 / (4 * QR8_0)) +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +#define QR8_1 1 +#define QI8_1 (QK8_1 / (4 * QR8_1)) +typedef struct { + half2 ds; // ds.x = delta, ds.y = sum + int8_t qs[QK8_0]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); + +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); +typedef void (*load_tiles_cuda_t)( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row); +typedef float (*vec_dot_q_mul_mat_cuda_t)( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k); + +//================================= k-quants + +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else +#define QK_K 256 +#define K_SCALE_SIZE 12 +#endif + +#define QR2_K 4 +#define QI2_K (QK_K / (4*QR2_K)) +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half2 dm; // super-block scale for quantized scales/mins +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +#define QR3_K 4 +#define QI3_K (QK_K / (4*QR3_K)) +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#ifdef GGML_QKK_64 + uint8_t scales[2]; // scales, quantized with 8 bits +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale +} block_q3_K; +//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block 
size/padding"); + +#define QR4_K 2 +#define QI4_K (QK_K / (4*QR4_K)) +#ifdef GGML_QKK_64 +typedef struct { + half dm[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +#endif + +#define QR5_K 2 +#define QI5_K (QK_K / (4*QR5_K)) +#ifdef GGML_QKK_64 +typedef struct { + half d; // super-block scale + int8_t scales[QK_K/16]; // block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif + +#define QR6_K 2 +#define QI6_K (QK_K / (4*QR6_K)) +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales + half d; // delta +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); + +#define WARP_SIZE 32 +#define MATRIX_ROW_PADDING 512 // last row of quant. 
matrices is a multiple of this to avoid out-of-bounds memory accesses + +#define CUDA_GELU_BLOCK_SIZE 256 +#define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_TANH_BLOCK_SIZE 256 +#define CUDA_RELU_BLOCK_SIZE 256 +#define CUDA_SQR_BLOCK_SIZE 256 +#define CUDA_CPY_BLOCK_SIZE 32 +#define CUDA_SCALE_BLOCK_SIZE 256 +#define CUDA_CLAMP_BLOCK_SIZE 256 +#define CUDA_ROPE_BLOCK_SIZE 256 +#define CUDA_SOFT_MAX_BLOCK_SIZE 1024 +#define CUDA_ALIBI_BLOCK_SIZE 32 +#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#define CUDA_GET_ROWS_BLOCK_SIZE 256 +#define CUDA_UPSCALE_BLOCK_SIZE 256 +#define CUDA_CONCAT_BLOCK_SIZE 256 +#define CUDA_PAD_BLOCK_SIZE 256 +#define CUDA_ACC_BLOCK_SIZE 256 +#define CUDA_IM2COL_BLOCK_SIZE 256 + +// dmmv = dequantize_mul_mat_vec +#ifndef GGML_CUDA_DMMV_X +#define GGML_CUDA_DMMV_X 32 +#endif +#ifndef GGML_CUDA_MMV_Y +#define GGML_CUDA_MMV_Y 1 +#endif + +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 2 +#else +static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); +#endif + +#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE +#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128 +#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE + +#define MUL_MAT_SRC1_COL_STRIDE 128 + +#define MAX_STREAMS 8 +static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } }; + +struct ggml_tensor_extra_gpu { + void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors + cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs +}; + +// this is faster on Windows +// probably because the Windows CUDA libraries forget to make this check before invoking the drivers +inline cudaError_t ggml_cuda_set_device(const int device) { + int current_device; + CUDA_CHECK(cudaGetDevice(¤t_device)); + + if (device == current_device) { + return cudaSuccess; + } + + return cudaSetDevice(device); +} + +static int g_device_count = -1; +static int g_main_device = 0; +static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; +static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; + +static void * g_scratch_buffer = nullptr; +static size_t g_scratch_size = 0; // disabled by default +static size_t g_scratch_offset = 0; + +static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; + +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + +static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); + a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32); + } + return a; +} + +static __device__ __forceinline__ float warp_reduce_max(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +} + +static __device__ __forceinline__ float op_repeat(const float a, const float b) { + return b; +} + +static __device__ __forceinline__ float op_add(const float a, const float b) { + return a + b; +} + +static __device__ __forceinline__ float op_mul(const float a, const float b) { + return a * b; +} + +static __device__ __forceinline__ float op_div(const float a, const float b) { + return a / b; +} + +template +static __global__ void k_bin_bcast(const src0_t * src0, const 
src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13) { + const int i0s = blockDim.x*blockIdx.x + threadIdx.x; + const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); + const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3; + const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3; + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + } +} + +template +static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13) { + + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? 
(float)src0_row[i0] : 0.0f, (float)src1_row[i10]); +} + +static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne, + const int ne10, const int ne11, const int ne12, + const int nb1, const int nb2, int offset) { + const int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i >= ne) { + return; + } + int src1_idx = i - offset; + int oz = src1_idx / nb2; + int oy = (src1_idx - (oz * nb2)) / nb1; + int ox = src1_idx % nb1; + if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { + dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; + } else { + dst[i] = x[i]; + } +} + +static __global__ void gelu_f32(const float * x, float * dst, const int k) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + float xi = x[i]; + dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); +} + +static __global__ void silu_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] / (1.0f + expf(-x[i])); +} + +static __global__ void gelu_quick_f32(const float *x, float *dst, int k) { + const float GELU_QUICK_COEF = -1.702f; + const int i = blockDim.x*blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i]))); +} + +static __global__ void tanh_f32(const float *x, float *dst, int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = tanhf(x[i]); +} + +static __global__ void relu_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = fmaxf(x[i], 0); +} + +static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope; +} + +static __global__ void sqr_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] * x[i]; +} + +template +static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float2 mean_var = make_float2(0.f, 0.f); + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row*ncols + col]; + mean_var.x += xi; + mean_var.y += xi * xi; + } + + // sum up partial sums + mean_var = warp_reduce_sum(mean_var); + if (block_size > WARP_SIZE) { + __shared__ float2 s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = mean_var; + } + __syncthreads(); + mean_var = s_sum[lane_id]; + mean_var = warp_reduce_sum(mean_var); + } + + const float mean = mean_var.x / ncols; + const float var = mean_var.y / ncols - mean * mean; + const float inv_std = rsqrtf(var + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; + } +} + +static __global__ void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + // operation + int offset_dst = + nidx + 
+ blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + if (blockIdx.z < ne02) { // src0 + int offset_src = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + + blockIdx.y * ne0 + + (blockIdx.z - ne02) * ne0 * gridDim.y; + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor) { + int ne0 = ne00 * scale_factor; + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + // operation + int i00 = nidx / scale_factor; + int i01 = blockIdx.y / scale_factor; + int offset_src = + i00 + + i01 * ne00 + + blockIdx.z * nb02; + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + dst[offset_dst] = x[offset_src]; +} + +static __global__ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + // operation + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) { + int offset_src = + nidx + + blockIdx.y * ne00 + + blockIdx.z * ne00 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + dst[offset_dst] = 0.0f; + } +} + +template +static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { + int start = blockIdx.x * group_size; + int end = start + group_size; + + start += threadIdx.x; + + if (end >= ne_elements) { + end = ne_elements; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += block_size) { + tmp += x[j]; + } + + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + float mean = tmp / group_size; + tmp = 0.0f; + + for (int j = start; j < end; j += block_size) { + float xi = x[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + float variance = tmp / group_size; + float scale = rsqrtf(variance + eps); + for (int j = start; j < end; j += block_size) { + dst[j] *= scale; + } +} + +template +static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row*ncols + col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < 
ncols; col += block_size) { + dst[row*ncols + col] = scale * x[row*ncols + col]; + } +} + +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q4_0 * x = (const block_q4_0 *) vx; + + const dfloat d = x[ib].d; + + const int vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + +#ifdef GGML_CUDA_F16 + v = __hsub2(v, {8.0f, 8.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 8.0f) * d; + v.y = (v.y - 8.0f) * d; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q4_1 * x = (const block_q4_1 *) vx; + + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); + + const int vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + +#ifdef GGML_CUDA_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else + v.x = (v.x * d) + m; + v.y = (v.y * d) + m; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q5_0 * x = (const block_q5_0 *) vx; + + const dfloat d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + +#ifdef GGML_CUDA_F16 + v = __hsub2(v, {16.0f, 16.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 16.0f) * d; + v.y = (v.y - 16.0f) * d; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q5_1 * x = (const block_q5_1 *) vx; + + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + +#ifdef GGML_CUDA_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else + v.x = (v.x * d) + m; + v.y = (v.y * d) + m; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q8_0 * x = (const block_q8_0 *) vx; + + const dfloat d = x[ib].d; + + v.x = x[ib].qs[iqs + 0]; + v.y = x[ib].qs[iqs + 1]; + +#ifdef GGML_CUDA_F16 + v = __hmul2(v, {d, d}); +#else + v.x *= d; + v.y *= d; +#endif // GGML_CUDA_F16 +} + +//================================== k-quants + +template +static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_q2_K * x = (const block_q2_K *) vx; + + const int tid = threadIdx.x; +#if QK_K == 256 + const int n = tid/32; + const int l = tid - 32*n; + const int is = 8*n + l/16; + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + float dall = __low2half(x[i].dm); + float dmin = __high2half(x[i].dm); + y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); + y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); + y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); +#else 
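+    // QK_K == 64 (GGML_QKK_64) build: scales/mins are packed 4 bits each, and every thread
+    // dequantizes two values of the block, at offsets 16*is + il and 16*is + il + 32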
+ const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const uint8_t q = x[i].qs[il] >> (2*is); + dst_t * y = yy + i*QK_K + 16*is + il; + float dall = __low2half(x[i].dm); + float dmin = __high2half(x[i].dm); + y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); +#endif + +} + +template +static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_q3_K * x = (const block_q3_K *) vx; + +#if QK_K == 256 + const int r = threadIdx.x/4; + const int tid = r/2; + const int is0 = r%2; + const int l0 = 16*is0 + 4*(threadIdx.x%4); + const int n = tid / 4; + const int j = tid - 4*n; + + uint8_t m = 1 << (4*n + j); + int is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + float d_all = x[i].d; + float dl = d_all * (us - 32); + + dst_t * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +#else + const int tid = threadIdx.x; + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const int im = il/8; // 0...1 + const int in = il%8; // 0...7 + + dst_t * y = yy + i*QK_K + 16*is + il; + + const uint8_t q = x[i].qs[il] >> (2*is); + const uint8_t h = x[i].hmask[in] >> (2*is + im); + const float d = (float)x[i].d; + + if (is == 0) { + y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } else { + y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 
0 : 4)); + } +#endif + +} + +#if QK_K == 256 +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} +#endif + +template +static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q4_K * x = (const block_q4_K *) vx; + + const int i = blockIdx.x; + +#if QK_K == 256 + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int is = 2*il; + const int n = 4; + + dst_t * y = yy + i*QK_K + 64*il + n*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * q = x[i].qs + 32*il + n*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q[l] & 0xF) - m1; + y[l +32] = d2 * (q[l] >> 4) - m2; + } +#else + const int tid = threadIdx.x; + const uint8_t * q = x[i].qs; + dst_t * y = yy + i*QK_K; + const float d = (float)x[i].dm[0]; + const float m = (float)x[i].dm[1]; + y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); + y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); +#endif +} + +template +static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q5_K * x = (const block_q5_K *) vx; + + const int i = blockIdx.x; + +#if QK_K == 256 + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int il = tid/16; // il is in 0...3 + const int ir = tid%16; // ir is in 0...15 + const int is = 2*il; // is is in 0...6 + + dst_t * y = yy + i*QK_K + 64*il + 2*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2*il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + const int tid = threadIdx.x; + const uint8_t q = x[i].qs[tid]; + const int im = tid/8; // 0...3 + const int in = tid%8; // 0...7 + const int is = tid/16; // 0 or 1 + const uint8_t h = x[i].qh[in] >> im; + const float d = x[i].d; + dst_t * y = yy + i*QK_K + tid; + y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); + y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 
0 : 16)); +#endif +} + +template +static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q6_K * x = (const block_q6_K *) vx; + + const int i = blockIdx.x; +#if QK_K == 256 + + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int ip = tid/32; // ip is 0 or 1 + const int il = tid - 32*ip; // 0...32 + const int is = 8*ip + il/16; + + dst_t * y = yy + i*QK_K + 128*ip + il; + + const float d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +#else + + // assume 32 threads + const int tid = threadIdx.x; + const int ip = tid/16; // 0 or 1 + const int il = tid - 16*ip; // 0...15 + + dst_t * y = yy + i*QK_K + 16*ip + il; + + const float d = x[i].d; + + const uint8_t ql = x[i].ql[16*ip + il]; + const uint8_t qh = x[i].qh[il] >> (2*ip); + const int8_t * sc = x[i].scales; + + y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); +#endif +} + +static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q2_K * x = (const block_q2_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 16/K_QUANTS_PER_ITERATION; + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
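+    // the offsets below pick this thread's slice of the 128-value half chosen by im:
+    // q_offset/y_offset address the 2-bit quants and the y values, s_offset the 8 scale/min bytes used for that half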
+ const int in = tid - step*im; // 0...15 or 0...7 + + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int s_offset = 8*im; + const int y_offset = 128*im + l0; + + uint32_t aux[4]; + const uint8_t * d = (const uint8_t *)aux; + const uint8_t * m = (const uint8_t *)(aux + 2); + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = a[1] & 0x0f0f0f0f; + aux[2] = (a[0] >> 4) & 0x0f0f0f0f; + aux[3] = (a[1] >> 4) & 0x0f0f0f0f; + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) + + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) + + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) + + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) + +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); + sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] + + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; + + } + tmp += dall * sum1 - dmin * sum2; + + } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; + + uint32_t uaux[2]; + const uint8_t * d = (const uint8_t *)uaux; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint32_t * s = (const uint32_t *)x[i].scales; + + uaux[0] = s[0] & 0x0f0f0f0f; + uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; + + const float2 dall = __half22float2(x[i].dm); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t ql = q[l]; + sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) + + y[l+16] * d[1] * ((ql >> 2) & 3) + + y[l+32] * d[2] * ((ql >> 4) & 3) + + y[l+48] * d[3] * ((ql >> 6) & 3); + sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; + } + tmp += dall.x * sum1 - dall.y * sum2; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q3_K * x = (const block_q3_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
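+    // in and l0 (set below) locate this thread's slice of the 128-value half chosen by im;
+    // m is the hmask bit that supplies the missing high bit of each 2-bit quant, and the packed
+    // 6-bit scales are unpacked into s[] via the kmask/s_shift masking inside the loop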
+ const int in = tid - step*im; // 0....15 or 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + const uint8_t * h = x[i].hmask + l0; + + const uint16_t * a = (const uint16_t *)x[i].scales; + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = x[i].d; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); + } + tmp += d * sum; + + } +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 + const int in = offset/8; // 0 or 1 + const int im = offset%8; // 0...7 + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint8_t * s = x[i].scales; + + const float dall = (float)x[i].d; + + float sum = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t hl = x[i].hmask[im+l] >> in; + const uint8_t ql = q[l]; + sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) + + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) + + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) + + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 
0 : 4)); + } + tmp += sum; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q4_K * x = (const block_q4_K *)vx + ib0; + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 4; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; + s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif + + } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * 
K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const float d = (float)x[i].dm[0]; + const float m = (float)x[i].dm[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { + + const int row = blockIdx.x; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; + + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 2; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + const uint8_t hm1 = 1 << (2*im); + const uint8_t hm2 = hm1 << 4; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + uint16_t q16[8]; + const uint8_t * q4 = (const uint8_t *)q16; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + const uint8_t * ql1 = x[i].qs + q_offset; + const uint8_t * qh = x[i].qh + l0; + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + float4 sum = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + const uint16_t * q1 = (const uint16_t *)ql1; + const uint16_t * q2 = q1 + 32; + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[8] & 0x0f0f; + q16[2] = (q1[0] >> 4) & 0x0f0f; + q16[3] = (q1[8] >> 4) & 0x0f0f; + q16[4] = q2[0] & 0x0f0f; + q16[5] = q2[8] & 0x0f0f; + q16[6] = (q2[0] >> 4) & 0x0f0f; + q16[7] = (q2[8] >> 4) & 0x0f0f; + for (int l = 0; l < n; ++l) { + sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) + + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0)); + sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) + + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0)); + sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) + + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0)); + sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 
16 : 0)) + + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0)); + smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] + + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; + } + tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; + } + +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int step = tid * K_QUANTS_PER_ITERATION; + const int im = step/8; + const int in = step%8; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const int8_t * s = x[i].scales; + const float * y = yy + i*QK_K + step; + const float d = x[i].d; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + const uint8_t h = x[i].qh[in+j] >> im; + sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) + + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) + + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) + + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16)); + } + tmp += sum; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q6_K * x = (const block_q6_K *)vx + ib0; + +#if QK_K == 256 + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... 
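+    // q6_K stores each value as 4 low bits in ql plus 2 high bits in qh; the ql/qh/s/y offsets
+    // computed below select this thread's slice of the 128-value half chosen by im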
+ const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = x[i].ql + ql_offset; + const uint8_t * qh = x[i].qh + qh_offset; + const int8_t * s = x[i].scales + s_offset; + + const float d = x[i].d; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = x[i].ql + step; + const uint8_t * qh = x[i].qh + step; + const int8_t * s = x[i].scales; + + const float d = x[i+0].d; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const half * x = (const half *) vx; + + // automatic half -> float type cast if dfloat == float + v.x = x[ib + iqs + 0]; + v.y = x[ib + iqs + 1]; +} + +static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const float * x = (const float *) vx; + + // automatic half -> float type cast if dfloat == float + v.x = x[ib + iqs + 0]; + v.y = x[ib + iqs + 1]; +} + +static __global__ void 
quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { + const int ix = blockDim.x*blockIdx.x + threadIdx.x; + + if (ix >= kx_padded) { + return; + } + + const int iy = blockDim.y*blockIdx.y + threadIdx.y; + + const int i_padded = iy*kx_padded + ix; + + block_q8_1 * y = (block_q8_1 *) vy; + + const int ib = i_padded / QK8_1; // block index + const int iqs = i_padded % QK8_1; // quant index + + const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; + float amax = fabsf(xi); + float sum = xi; + +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); + sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); + } + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + reinterpret_cast(y[ib].ds.x) = d; + reinterpret_cast(y[ib].ds.y) = sum; +} + +template +static __global__ void k_get_rows( + const void * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12/*, size_t s13*/) { + + const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; + const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; + + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(src0_row, ib, iqs, v); + + dst_row[iybs + iqs + 0] = v.x; + dst_row[iybs + iqs + y_offset] = v.y; +} + +template +static __global__ void k_get_rows_float( + const src0_t * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12/*, size_t s13*/) { + + const int i00 = blockIdx.x*blockDim.x + threadIdx.x; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; + const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = src0_row[i00]; +} + +template +static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { + const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; + + if (i >= k) { + return; + } + + const int ib = i/qk; // block index + const int iqs = (i%qk)/qr; // quant index + const int iybs = i - i%qk; // y block start index + const int y_offset = qr == 1 ? 
1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(vx, ib, iqs, v); + + y[iybs + iqs + 0] = v.x; + y[iybs + iqs + y_offset] = v.y; +} + +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called +// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q + +#define VDR_Q4_0_Q8_1_MMVQ 2 +#define VDR_Q4_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( + const int * v, const int * u, const float & d4, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q4_1_Q8_1_MMVQ 2 +#define VDR_Q4_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( + const int * v, const int * u, const half2 & dm4, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm4, ds8)); + const float d4d8 = tmp.x; + const float m4s8 = tmp.y; +#else + const float2 dm4f = __half22float2(dm4); + const float2 ds8f = __half22float2(ds8); + const float d4d8 = dm4f.x * ds8f.x; + const float m4s8 = dm4f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q5_0_Q8_1_MMVQ 2 +#define VDR_Q5_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( + const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 16 from each quant value + return d5 * 
(sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q5_1_Q8_1_MMVQ 2 +#define VDR_Q5_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( + const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm5, ds8)); + const float d5d8 = tmp.x; + const float m5s8 = tmp.y; +#else + const float2 dm5f = __half22float2(dm5); + const float2 ds8f = __half22float2(ds8); + const float d5d8 = dm5f.x * ds8f.x; + const float m5s8 = dm5f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it + return sumi*d5d8 + m5s8 / (QI5_1 / vdr); + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q8_0_Q8_1_MMVQ 2 +#define VDR_Q8_0_Q8_1_MMQ 8 + +template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const float & d8_0, const float & d8_1) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + + return d8_0*d8_1 * sumi; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( + const int * v, const int * u, const half2 & dm8, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm8, ds8)); + const float d8d8 = tmp.x; + const float m8s8 = tmp.y; +#else + const float2 dm8f = __half22float2(dm8); + const float2 ds8f = __half22float2(ds8); + const float d8d8 = dm8f.x * ds8f.x; + const float m8s8 = dm8f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it + return sumi*d8d8 + m8s8 / (QI8_1 / vdr); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q2_K_Q8_1_MMVQ 1 +#define VDR_Q2_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( + const 
int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2*i]; + + const int vi = (v >> (2*i)) & 0x03030303; + + sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + } + + const float2 dm2f = __half22float2(dm2); + + return dm2f.x*sumf_d - dm2f.y*sumf_m; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float & d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi_d = 0; + int sumi_m = 0; + +#pragma unroll + for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { + int sumi_d_sc = 0; + + const int sc = scales[i0 / (QI8_1/2)]; + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product + sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + } + + sumi_d += sumi_d_sc * (sc & 0xF); + } + + const float2 dm2f = __half22float2(dm2); + + return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q3_K_Q8_1_MMVQ 1 +#define VDR_Q3_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const int & scale_offset, const float & d3, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2*i; + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + const int vil = (vl >> (2*i)) & 0x03030303; + + const int vih = ((vh >> i) << 2) & 0x04040404; + + const int vi = __vsubss4(vil, vih); + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d3 * sumf; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d3, const float & d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { + int sumi_sc 
= 0; + + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product + } + + sumi += sumi_sc * scales[i0 / (QI8_1/2)]; + } + + return d3*d8 * sumi; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q4_K_Q8_1_MMVQ 2 +#define VDR_Q4_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; + + const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q5_K_Q8_1_MMVQ 2 +#define VDR_Q5_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( + const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; + + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; + + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; + + const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); + + } + + 
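+    // combine the partial sums: dm5.x is the super-block scale and dm5.y the
+    // super-block min, so the min contribution accumulated in sumf_m is subtracted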
const float2 dm5f = __half22float2(dm5); + + return dm5f.x*sumf_d - dm5f.y*sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q6_K_Q8_1_MMVQ 1 +#define VDR_Q6_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4*i]; + + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + + const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d*sumf; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, + const float & d6, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { + int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale + +#pragma unroll + for (int i = i0; i < i0 + 2; ++i) { + sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product + sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product + + sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product + sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product + } + + sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); + } + + return d6 * sumf_d; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + int v[VDR_Q4_0_Q8_1_MMVQ]; + int u[2*VDR_Q4_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2*i+0] = 
get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + } + + return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; (void)x_sc; + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; + + *x_ql = tile_x_qs; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q4_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; (void)x_sc; + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_0; + const int kqsx = k % QI4_0; + + const block_q4_0 * bx0 = (const block_q4_0 *) vx; + + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { + int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; + } +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; (void)x_sc; + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const float * x_dmf = (const float *) x_dm; + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + } + + return vec_dot_q4_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + int v[VDR_Q4_1_Q8_1_MMVQ]; + int u[2*VDR_Q4_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); + } + + return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; (void)x_sc; + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; + 
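+    // the extra mmq_y ints pad each tile row to WARP_SIZE+1 entries, matching the
+    // i * (WARP_SIZE + 1) + k indexing used below (the usual padding to avoid shared memory bank conflicts)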
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; + + *x_ql = tile_x_qs; + *x_dm = tile_x_dm; +} + +template static __device__ __forceinline__ void load_tiles_q4_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; (void)x_sc; + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_1; + const int kqsx = k % QI4_1; + + const block_q4_1 * bx0 = (const block_q4_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { + int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + } +} + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; (void)x_sc; + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + } + + return vec_dot_q4_1_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + + int vl[VDR_Q5_0_Q8_1_MMVQ]; + int vh[VDR_Q5_0_Q8_1_MMVQ]; + int u[2*VDR_Q5_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); + vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); + } + + return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; (void)x_sc; + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; + + *x_ql = tile_x_ql; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q5_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; (void)x_sc; + + 
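+    // GGML_CUDA_ASSUME gives the compiler range hints on i_offset and k so the
+    // tile index arithmetic below can be optimized more aggressively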
GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_0; + const int kqsx = k % QI5_0; + + const block_q5_0 * bx0 = (const block_q5_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx; + + const int ql = get_int_from_uint8(bxi->qs, kqsx); + const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { + int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; + } +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; (void)x_sc; + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + int u[2*VDR_Q5_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + } + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + + int vl[VDR_Q5_1_Q8_1_MMVQ]; + int vh[VDR_Q5_1_Q8_1_MMVQ]; + int u[2*VDR_Q5_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); + vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); + } + + return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; (void)x_sc; + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * 
(WARP_SIZE/QI5_1) + mmq_y/QI5_1]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; +} + +template static __device__ __forceinline__ void load_tiles_q5_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; (void)x_sc; + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_1; + const int kqsx = k % QI5_1; + + const block_q5_1 * bx0 = (const block_q5_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { + int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + } +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; (void)x_sc; + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; + + int u[2*VDR_Q5_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + } + + return vec_dot_q8_1_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + + int v[VDR_Q8_0_Q8_1_MMVQ]; + int u[VDR_Q8_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_int8(bq8_0->qs, iqs + i); + u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); +} + +template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; (void)x_sc; + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ float 
tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; + + *x_ql = tile_x_qs; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q8_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; (void)x_sc; + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI8_0; + const int kqsx = k % QI8_0; + float * x_dmf = (float *) x_dm; + + const block_q8_0 * bx0 = (const block_q8_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { + int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; + } +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; (void)x_sc; + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], + y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q2_K * bq2_K = (const block_q2_K *) vbq; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const uint8_t * scales = bq2_K->scales + scale_offset; + + const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; + +#pragma unroll + for (int i = 0; i < QR2_K; ++ i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2half(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; + + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q2_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + 
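+    // k is this thread's column within the WARP_SIZE-wide tile row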
GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI2_K; + const int kqsx = k % QI2_K; + + const block_q2_K * bx0 = (const block_q2_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { + int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); + + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); + } +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; + + const int kbx = k / QI2_K; + const int ky = (k % QI2_K) * QR2_K; + const float * y_df = (const float *) y_ds; + + int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; + + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); + +#pragma unroll + for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { + v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; + } + + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + + const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; + return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q3_K * bq3_K = (const block_q3_K *) vbq; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const float d = bq3_K->d; + + const int vl = get_int_from_uint8(bq3_K->qs, iqs); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2half(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; + __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + 
mmq_y/4]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_qh = tile_x_qh; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q3_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI3_K; + const int kqsx = k % QI3_K; + + const block_q3_K * bx0 = (const block_q3_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { + int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { + int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); + + const int ksc = k % (QI3_K/4); + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = __vsubss4(sc_low | sc_high, 0x20202020); + + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; + } +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kbx = k / QI3_K; + const int ky = (k % QI3_K) * QR3_K; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + + int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int shift = 2 * ((ky % 32) / 8); + const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; + + const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vlh = 
(vh << 2) & 0x04040404; + + v[l] = __vsubss4(vll, vlh); + } + + const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; + return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); + + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + v[0] = q4[0]; + v[1] = q4[4]; + + const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2half(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + +#else + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + const uint16_t * a = (const uint16_t *)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const float dall = bq4_K->dm[0]; + const float dmin = bq4_K->dm[1]; + + const float d8_1 = __low2float(bq8_1[0].ds); + const float d8_2 = __low2float(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * q4 = (const int *)bq4_K->qs + (iqs/2); + const int v1 = q4[0]; + const int v2 = q4[4]; + + const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + + return dall * sumf_d - dmin * sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A + +#endif +} + +template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; + + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + 
mmq_y/QI4_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q4_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_K; // == 0 if QK_K == 256 + const int kqsx = k % QI4_K; // == k if QK_K == 256 + + const block_q4_K * bx0 = (const block_q4_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { + int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; + +#if QK_K == 256 + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; +#else + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]}; +#endif + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); + + const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; + return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + int vl[2]; + int vh[2]; + int u[2*QR5_K]; + float d8[QR5_K]; + + const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); + + vl[0] = ql[0]; + vl[1] = ql[4]; + + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; + + const 
uint16_t * scales = (const uint16_t *)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); + +#else + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + const int8_t * s = bq5_K->scales; + + const float d = bq5_K->d; + + const float d8_1 = __low2half(bq8_1[0].ds); + const float d8_2 = __low2half(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * ql = (const int *)bq5_K->qs + (iqs/2); + const int vl1 = ql[0]; + const int vl2 = ql[4]; + + const int step = 4 * (iqs/2); // 0, 4, 8, 12 + const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6 + const int in = step%8; // 0, 4, 0, 4 + const int vh = (*((const int *)(bq5_K->qh + in))) >> im; + + const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); + const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); + const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); + const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); + + const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); + + return d * sumf_d; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A + +#endif +} + +template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q5_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_K; // == 0 if QK_K == 256 + const int kqsx = k % QI5_K; // == k if QK_K == 256 + + const block_q5_K * bx0 = (const block_q5_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR5_K*kqsx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + 
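+        // split the packed 4-bit values and merge in the 5th bits from qh so the
+        // shared-memory tile stores fully reconstructed 5-bit quants (no further unpacking in the dot product)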
const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { + int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; + +#if QK_K == 256 + x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; +#endif + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); + + const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; + const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; + return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q6_K * bq6_K = (const block_q6_K *) vbq; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); + + const int vl = get_int_from_uint8(bq6_K->ql, iqs); + const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; + + const int8_t * scales = bq6_K->scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds); + } + + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; + + __shared__ int 
tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q6_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + const block_q6_K * bx0 = (const block_q6_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR6_K*kqsx; + + const int ql = get_int_from_uint8(bxi->ql, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); + const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; + const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + + const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; + const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { + int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; + + x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); + } +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); + + const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; + const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; + return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); +} + +template +static __device__ __forceinline__ void mul_mat_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int 
nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + const int blocks_per_warp = WARP_SIZE / qi; + + const int & ncols_dst = ncols_y; + + const int row_dst_0 = blockIdx.x*mmq_y; + const int & row_x_0 = row_dst_0; + + const int col_dst_0 = blockIdx.y*mmq_x; + const int & col_y_0 = col_dst_0; + + int * tile_x_ql = nullptr; + half2 * tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + + allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); + + __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; + __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; + + float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; + + for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { + + load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, + threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x); + +#pragma unroll + for (int ir = 0; ir < qr; ++ir) { + const int kqs = ir*WARP_SIZE + threadIdx.x; + const int kbxd = kqs / QI8_1; + +#pragma unroll + for (int i = 0; i < mmq_x; i += nwarps) { + const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses + + const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; + + const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; + tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); + } + +#pragma unroll + for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { + const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; + const int kby = threadIdx.x % (WARP_SIZE/QI8_1); + const int col_y_eff = min(col_y_0 + ids, ncols_y-1); + + // if the sum is not needed it's faster to transform the scale to f32 ahead of time + const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; + half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; + if (need_sum) { + *dsi_dst = *dsi_src; + } else { + float * dfi_dst = (float *) dsi_dst; + *dfi_dst = __low2half(*dsi_src); + } + } + + __syncthreads(); + +// #pragma unroll // unrolling this loop causes too much register pressure + for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + sum[i/WARP_SIZE][j/nwarps] += vec_dot( + tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, + threadIdx.x + i, threadIdx.y + j, k); + } + } + } + + __syncthreads(); + } + } + +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { + const int col_dst = col_dst_0 + j + threadIdx.y; + + if (col_dst >= ncols_dst) { + return; + } + +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + const int row_dst = row_dst_0 + threadIdx.x + i; + + if (row_dst >= nrows_dst) { + continue; + } + + dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; + } + } +} + +#define MMQ_X_Q4_0_RDNA2 64 +#define MMQ_Y_Q4_0_RDNA2 128 +#define NWARPS_Q4_0_RDNA2 8 +#define MMQ_X_Q4_0_RDNA1 64 +#define MMQ_Y_Q4_0_RDNA1 64 +#define NWARPS_Q4_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_0_AMPERE 4 +#define MMQ_Y_Q4_0_AMPERE 32 +#define NWARPS_Q4_0_AMPERE 4 +#else +#define MMQ_X_Q4_0_AMPERE 64 +#define MMQ_Y_Q4_0_AMPERE 
128 +#define NWARPS_Q4_0_AMPERE 4 +#endif +#define MMQ_X_Q4_0_PASCAL 64 +#define MMQ_Y_Q4_0_PASCAL 64 +#define NWARPS_Q4_0_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + mul_mat_q4_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q4_0_RDNA2; + const int mmq_y = MMQ_Y_Q4_0_RDNA2; + const int nwarps = NWARPS_Q4_0_RDNA2; +#else + const int mmq_x = MMQ_X_Q4_0_RDNA1; + const int mmq_y = MMQ_Y_Q4_0_RDNA1; + const int nwarps = NWARPS_Q4_0_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q4_0_AMPERE; + const int mmq_y = MMQ_Y_Q4_0_AMPERE; + const int nwarps = NWARPS_Q4_0_AMPERE; + + mul_mat_q, + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q4_0_PASCAL; + const int mmq_y = MMQ_Y_Q4_0_PASCAL; + const int nwarps = NWARPS_Q4_0_PASCAL; + + mul_mat_q, + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q4_0_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q4_1_RDNA2 64 +#define MMQ_Y_Q4_1_RDNA2 128 +#define NWARPS_Q4_1_RDNA2 8 +#define MMQ_X_Q4_1_RDNA1 64 +#define MMQ_Y_Q4_1_RDNA1 64 +#define NWARPS_Q4_1_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_1_AMPERE 4 +#define MMQ_Y_Q4_1_AMPERE 32 +#define NWARPS_Q4_1_AMPERE 4 +#else +#define MMQ_X_Q4_1_AMPERE 64 +#define MMQ_Y_Q4_1_AMPERE 128 +#define NWARPS_Q4_1_AMPERE 4 +#endif +#define MMQ_X_Q4_1_PASCAL 64 +#define MMQ_Y_Q4_1_PASCAL 64 +#define NWARPS_Q4_1_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#elif __CUDA_ARCH__ < CC_VOLTA + __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2) +#endif // __CUDA_ARCH__ < CC_VOLTA + mul_mat_q4_1( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q4_1_RDNA2; + const int mmq_y = MMQ_Y_Q4_1_RDNA2; + const int nwarps = NWARPS_Q4_1_RDNA2; +#else + const int mmq_x = MMQ_X_Q4_1_RDNA1; + const int mmq_y = MMQ_Y_Q4_1_RDNA1; + const int nwarps = NWARPS_Q4_1_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q4_1_AMPERE; + const int mmq_y = 
MMQ_Y_Q4_1_AMPERE; + const int nwarps = NWARPS_Q4_1_AMPERE; + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q4_1_PASCAL; + const int mmq_y = MMQ_Y_Q4_1_PASCAL; + const int nwarps = NWARPS_Q4_1_PASCAL; + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q4_1_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q5_0_RDNA2 64 +#define MMQ_Y_Q5_0_RDNA2 128 +#define NWARPS_Q5_0_RDNA2 8 +#define MMQ_X_Q5_0_RDNA1 64 +#define MMQ_Y_Q5_0_RDNA1 64 +#define NWARPS_Q5_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q5_0_AMPERE 4 +#define MMQ_Y_Q5_0_AMPERE 32 +#define NWARPS_Q5_0_AMPERE 4 +#else +#define MMQ_X_Q5_0_AMPERE 128 +#define MMQ_Y_Q5_0_AMPERE 64 +#define NWARPS_Q5_0_AMPERE 4 +#endif +#define MMQ_X_Q5_0_PASCAL 64 +#define MMQ_Y_Q5_0_PASCAL 64 +#define NWARPS_Q5_0_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + mul_mat_q5_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q5_0_RDNA2; + const int mmq_y = MMQ_Y_Q5_0_RDNA2; + const int nwarps = NWARPS_Q5_0_RDNA2; +#else + const int mmq_x = MMQ_X_Q5_0_RDNA1; + const int mmq_y = MMQ_Y_Q5_0_RDNA1; + const int nwarps = NWARPS_Q5_0_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q5_0_AMPERE; + const int mmq_y = MMQ_Y_Q5_0_AMPERE; + const int nwarps = NWARPS_Q5_0_AMPERE; + + mul_mat_q, + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q5_0_PASCAL; + const int mmq_y = MMQ_Y_Q5_0_PASCAL; + const int nwarps = NWARPS_Q5_0_PASCAL; + + mul_mat_q, + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q5_0_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q5_1_RDNA2 64 +#define MMQ_Y_Q5_1_RDNA2 128 +#define NWARPS_Q5_1_RDNA2 8 +#define MMQ_X_Q5_1_RDNA1 64 +#define MMQ_Y_Q5_1_RDNA1 64 +#define NWARPS_Q5_1_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q5_1_AMPERE 4 +#define MMQ_Y_Q5_1_AMPERE 32 +#define NWARPS_Q5_1_AMPERE 4 +#else +#define MMQ_X_Q5_1_AMPERE 128 +#define MMQ_Y_Q5_1_AMPERE 64 +#define NWARPS_Q5_1_AMPERE 4 +#endif +#define MMQ_X_Q5_1_PASCAL 64 +#define MMQ_Y_Q5_1_PASCAL 64 +#define NWARPS_Q5_1_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2) +#endif // defined(RDNA3) 
|| defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +mul_mat_q5_1( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q5_1_RDNA2; + const int mmq_y = MMQ_Y_Q5_1_RDNA2; + const int nwarps = NWARPS_Q5_1_RDNA2; +#else + const int mmq_x = MMQ_X_Q5_1_RDNA1; + const int mmq_y = MMQ_Y_Q5_1_RDNA1; + const int nwarps = NWARPS_Q5_1_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q5_1_AMPERE; + const int mmq_y = MMQ_Y_Q5_1_AMPERE; + const int nwarps = NWARPS_Q5_1_AMPERE; + + mul_mat_q, + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q5_1_PASCAL; + const int mmq_y = MMQ_Y_Q5_1_PASCAL; + const int nwarps = NWARPS_Q5_1_PASCAL; + + mul_mat_q, + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q5_1_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q8_0_RDNA2 64 +#define MMQ_Y_Q8_0_RDNA2 128 +#define NWARPS_Q8_0_RDNA2 8 +#define MMQ_X_Q8_0_RDNA1 64 +#define MMQ_Y_Q8_0_RDNA1 64 +#define NWARPS_Q8_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q8_0_AMPERE 4 +#define MMQ_Y_Q8_0_AMPERE 32 +#define NWARPS_Q8_0_AMPERE 4 +#else +#define MMQ_X_Q8_0_AMPERE 128 +#define MMQ_Y_Q8_0_AMPERE 64 +#define NWARPS_Q8_0_AMPERE 4 +#endif +#define MMQ_X_Q8_0_PASCAL 64 +#define MMQ_Y_Q8_0_PASCAL 64 +#define NWARPS_Q8_0_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + mul_mat_q8_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q8_0_RDNA2; + const int mmq_y = MMQ_Y_Q8_0_RDNA2; + const int nwarps = NWARPS_Q8_0_RDNA2; +#else + const int mmq_x = MMQ_X_Q8_0_RDNA1; + const int mmq_y = MMQ_Y_Q8_0_RDNA1; + const int nwarps = NWARPS_Q8_0_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q8_0_AMPERE; + const int mmq_y = MMQ_Y_Q8_0_AMPERE; + const int nwarps = NWARPS_Q8_0_AMPERE; + + mul_mat_q, + load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q8_0_PASCAL; + const int mmq_y = MMQ_Y_Q8_0_PASCAL; + const int nwarps = NWARPS_Q8_0_PASCAL; + + mul_mat_q, + load_tiles_q8_0, 
VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q8_0_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q2_K_RDNA2 64 +#define MMQ_Y_Q2_K_RDNA2 128 +#define NWARPS_Q2_K_RDNA2 8 +#define MMQ_X_Q2_K_RDNA1 128 +#define MMQ_Y_Q2_K_RDNA1 32 +#define NWARPS_Q2_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q2_K_AMPERE 4 +#define MMQ_Y_Q2_K_AMPERE 32 +#define NWARPS_Q2_K_AMPERE 4 +#else +#define MMQ_X_Q2_K_AMPERE 64 +#define MMQ_Y_Q2_K_AMPERE 128 +#define NWARPS_Q2_K_AMPERE 4 +#endif +#define MMQ_X_Q2_K_PASCAL 64 +#define MMQ_Y_Q2_K_PASCAL 64 +#define NWARPS_Q2_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +mul_mat_q2_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q2_K_RDNA2; + const int mmq_y = MMQ_Y_Q2_K_RDNA2; + const int nwarps = NWARPS_Q2_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q2_K_RDNA1; + const int mmq_y = MMQ_Y_Q2_K_RDNA1; + const int nwarps = NWARPS_Q2_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q2_K_AMPERE; + const int mmq_y = MMQ_Y_Q2_K_AMPERE; + const int nwarps = NWARPS_Q2_K_AMPERE; + + mul_mat_q, + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q2_K_PASCAL; + const int mmq_y = MMQ_Y_Q2_K_PASCAL; + const int nwarps = NWARPS_Q2_K_PASCAL; + + mul_mat_q, + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q2_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q3_K_RDNA2 128 +#define MMQ_Y_Q3_K_RDNA2 64 +#define NWARPS_Q3_K_RDNA2 8 +#define MMQ_X_Q3_K_RDNA1 32 +#define MMQ_Y_Q3_K_RDNA1 128 +#define NWARPS_Q3_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q3_K_AMPERE 4 +#define MMQ_Y_Q3_K_AMPERE 32 +#define NWARPS_Q3_K_AMPERE 4 +#else +#define MMQ_X_Q3_K_AMPERE 128 +#define MMQ_Y_Q3_K_AMPERE 128 +#define NWARPS_Q3_K_AMPERE 4 +#endif +#define MMQ_X_Q3_K_PASCAL 64 +#define MMQ_Y_Q3_K_PASCAL 64 +#define NWARPS_Q3_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#elif __CUDA_ARCH__ < CC_VOLTA + __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2) +#endif // __CUDA_ARCH__ < CC_VOLTA + mul_mat_q3_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && 
defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q3_K_RDNA2; + const int mmq_y = MMQ_Y_Q3_K_RDNA2; + const int nwarps = NWARPS_Q3_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q3_K_RDNA1; + const int mmq_y = MMQ_Y_Q3_K_RDNA1; + const int nwarps = NWARPS_Q3_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q3_K_AMPERE; + const int mmq_y = MMQ_Y_Q3_K_AMPERE; + const int nwarps = NWARPS_Q3_K_AMPERE; + + mul_mat_q, + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q3_K_PASCAL; + const int mmq_y = MMQ_Y_Q3_K_PASCAL; + const int nwarps = NWARPS_Q3_K_PASCAL; + + mul_mat_q, + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q3_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q4_K_RDNA2 64 +#define MMQ_Y_Q4_K_RDNA2 128 +#define NWARPS_Q4_K_RDNA2 8 +#define MMQ_X_Q4_K_RDNA1 32 +#define MMQ_Y_Q4_K_RDNA1 64 +#define NWARPS_Q4_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_K_AMPERE 4 +#define MMQ_Y_Q4_K_AMPERE 32 +#define NWARPS_Q4_K_AMPERE 4 +#else +#define MMQ_X_Q4_K_AMPERE 64 +#define MMQ_Y_Q4_K_AMPERE 128 +#define NWARPS_Q4_K_AMPERE 4 +#endif +#define MMQ_X_Q4_K_PASCAL 64 +#define MMQ_Y_Q4_K_PASCAL 64 +#define NWARPS_Q4_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#elif __CUDA_ARCH__ < CC_VOLTA + __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2) +#endif // __CUDA_ARCH__ < CC_VOLTA + mul_mat_q4_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q4_K_RDNA2; + const int mmq_y = MMQ_Y_Q4_K_RDNA2; + const int nwarps = NWARPS_Q4_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q4_K_RDNA1; + const int mmq_y = MMQ_Y_Q4_K_RDNA1; + const int nwarps = NWARPS_Q4_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q4_K_AMPERE; + const int mmq_y = MMQ_Y_Q4_K_AMPERE; + const int nwarps = NWARPS_Q4_K_AMPERE; + + mul_mat_q, + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q4_K_PASCAL; + const int mmq_y = MMQ_Y_Q4_K_PASCAL; + const int nwarps = NWARPS_Q4_K_PASCAL; + + mul_mat_q, + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q4_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q5_K_RDNA2 64 +#define MMQ_Y_Q5_K_RDNA2 128 +#define 
NWARPS_Q5_K_RDNA2 8 +#define MMQ_X_Q5_K_RDNA1 32 +#define MMQ_Y_Q5_K_RDNA1 64 +#define NWARPS_Q5_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q5_K_AMPERE 4 +#define MMQ_Y_Q5_K_AMPERE 32 +#define NWARPS_Q5_K_AMPERE 4 +#else +#define MMQ_X_Q5_K_AMPERE 64 +#define MMQ_Y_Q5_K_AMPERE 128 +#define NWARPS_Q5_K_AMPERE 4 +#endif +#define MMQ_X_Q5_K_PASCAL 64 +#define MMQ_Y_Q5_K_PASCAL 64 +#define NWARPS_Q5_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +mul_mat_q5_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q5_K_RDNA2; + const int mmq_y = MMQ_Y_Q5_K_RDNA2; + const int nwarps = NWARPS_Q5_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q5_K_RDNA1; + const int mmq_y = MMQ_Y_Q5_K_RDNA1; + const int nwarps = NWARPS_Q5_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q5_K_AMPERE; + const int mmq_y = MMQ_Y_Q5_K_AMPERE; + const int nwarps = NWARPS_Q5_K_AMPERE; + + mul_mat_q, + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q5_K_PASCAL; + const int mmq_y = MMQ_Y_Q5_K_PASCAL; + const int nwarps = NWARPS_Q5_K_PASCAL; + + mul_mat_q, + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q5_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q6_K_RDNA2 64 +#define MMQ_Y_Q6_K_RDNA2 128 +#define NWARPS_Q6_K_RDNA2 8 +#define MMQ_X_Q6_K_RDNA1 32 +#define MMQ_Y_Q6_K_RDNA1 64 +#define NWARPS_Q6_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q6_K_AMPERE 4 +#define MMQ_Y_Q6_K_AMPERE 32 +#define NWARPS_Q6_K_AMPERE 4 +#else +#define MMQ_X_Q6_K_AMPERE 64 +#define MMQ_Y_Q6_K_AMPERE 64 +#define NWARPS_Q6_K_AMPERE 4 +#endif +#define MMQ_X_Q6_K_PASCAL 64 +#define MMQ_Y_Q6_K_PASCAL 64 +#define NWARPS_Q6_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#elif __CUDA_ARCH__ < CC_VOLTA + __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2) +#endif // __CUDA_ARCH__ < CC_VOLTA + mul_mat_q6_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q6_K_RDNA2; + const int mmq_y = MMQ_Y_Q6_K_RDNA2; + const int nwarps = NWARPS_Q6_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q6_K_RDNA1; + const int mmq_y = MMQ_Y_Q6_K_RDNA1; + const int nwarps = 
NWARPS_Q6_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q6_K_AMPERE; + const int mmq_y = MMQ_Y_Q6_K_AMPERE; + const int nwarps = NWARPS_Q6_K_AMPERE; + + mul_mat_q, + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q6_K_PASCAL; + const int mmq_y = MMQ_Y_Q6_K_PASCAL; + const int nwarps = NWARPS_Q6_K_PASCAL; + + mul_mat_q, + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q6_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +template +static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index + + const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +template +static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = blockIdx.x*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + + const int tid = threadIdx.x; + + const int iter_stride = 2*GGML_CUDA_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter + const int y_offset = qr == 1 ? 
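mul_mat_vec_q above finishes each row by folding the per-thread partial dot products across the warp with a butterfly of __shfl_xor_sync calls before lane 0 writes the result; the same reduction reappears in dequantize_mul_mat_vec and the f16 mat-vec kernels that follow. A standalone sketch of that reduction, assuming 32-lane warps; the kernel and buffer names are made up for illustration.

#include <cstdio>

// Sum one value per thread across a 32-lane warp with a butterfly of XOR shuffles.
__device__ float warp_sum(float v) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);
    }
    return v; // every lane now holds the full warp sum
}

__global__ void warp_sum_demo(const float * x, float * out) {
    const float s = warp_sum(x[threadIdx.x]);
    if (threadIdx.x == 0) {
        *out = s; // only lane 0 writes, as in the mat-vec kernels above
    }
}

int main() {
    float *x, *out;
    cudaMallocManaged(&x, 32*sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < 32; ++i) x[i] = 1.0f;
    warp_sum_demo<<<1, 32>>>(x, out);
    cudaDeviceSynchronize();
    printf("warp sum = %f\n", *out); // expected 32
    cudaFree(x); cudaFree(out);
}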
1 : qk/2; + +// partial sum for each thread +#ifdef GGML_CUDA_F16 + half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_CUDA_F16 + + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel(vx, ib, iqs + j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_CUDA_F16 + tmp += __hmul2(v, { + y[iybs + iqs + j/qr + 0], + y[iybs + iqs + j/qr + y_offset] + }); +#else + tmp += v.x * y[iybs + iqs + j/qr + 0]; + tmp += v.y * y[iybs + iqs + j/qr + y_offset]; +#endif // GGML_CUDA_F16 + } + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { +#ifdef GGML_CUDA_F16 + dst[row] = tmp.x + tmp.y; +#else + dst[row] = tmp; +#endif // GGML_CUDA_F16 + } +} + +static __global__ void mul_mat_p021_f16_f32( + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { + + const half * x = (const half *) vx; + + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / (nchannels_y / nchannels_x); + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + float tmp = 0.0f; + + for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { + const int col_x = col_x0 + threadIdx.x; + + if (col_x >= ncols_x) { + break; + } + + // x is transposed and permuted + const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; + const float xi = __half2float(x[ix]); + + const int row_y = col_x; + + + // y is not transposed but permuted + const int iy = channel*nrows_y + row_y; + + tmp += xi * y[iy]; + } + + // dst is not transposed and not permuted + const int idst = channel*nrows_dst + row_dst; + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[idst] = tmp; + } +} + +static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, + const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { + + const half * x = (const half *) vx; + + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / channel_x_divisor; + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + const int idst = channel*nrows_dst + row_dst; + + float tmp = 0.0f; + + for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { + const int col_x = col_x0 + threadIdx.x; + + if (col_x >= ncols_x) { + break; 
+ } + + const int row_y = col_x; + + const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; + const int iy = channel*nrows_y + row_y; + + const float xi = __half2float(x[ix]); + + tmp += xi * y[iy]; + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[idst] = tmp; + } +} + +static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + float * dsti = (float *) cdsti; + + *dsti = *xi; +} + +static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + half * dsti = (half *) cdsti; + + *dsti = __float2half(*xi); +} + +static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { + const half * xi = (const half *) cxi; + half * dsti = (half *) cdsti; + + *dsti = *xi; +} + +template +static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor + // then combine those indices with the corresponding byte offsets to get the total offsets + const int i02 = i / (ne00*ne01); + const int i01 = (i - i02*ne01*ne00) / ne00; + const int i00 = i - i02*ne01*ne00 - i01*ne00; + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; + + const int i12 = i / (ne10*ne11); + const int i11 = (i - i12*ne10*ne11) / ne10; + const int i10 = i - i12*ne10*ne11 - i11*ne10; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; + + cpy_1(cx + x_offset, cdst + dst_offset); +} + +static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q8_0 * dsti = (block_q8_0 *) cdsti; + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = xi[j]; + amax = fmaxf(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dsti->d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = xi[j]*id; + + dsti->qs[j] = roundf(x0); + } +} + +static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_0 * dsti = (block_q4_0 *) cdsti; + + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_0; ++j) { + const float v = xi[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = d ? 1.0f/d : 0.0f; + + dsti->d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = xi[0 + j]*id; + const float x1 = xi[QK4_0/2 + j]*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + +static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_1 * dsti = (block_q4_1 *) cdsti; + + float vmin = FLT_MAX; + float vmax = -FLT_MAX; + + for (int j = 0; j < QK4_1; ++j) { + const float v = xi[j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = d ? 
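cpy_blck_f32_q8_0 and cpy_blck_f32_q4_0 above quantize one block per call: q8_0 keeps a single scale d = amax/127 plus 32 signed bytes, while q4_0/q4_1 pack two 4-bit values per output byte. A host-side sketch of the q8_0 step; the struct is only an illustrative mirror of the layout (the real block stores d in fp16) and none of these names come from the patch.

#include <cmath>
#include <cstdint>

constexpr int QK8_0_REF = 32;

struct block_q8_0_ref {        // illustrative mirror; real block_q8_0 stores d as fp16
    float  d;                  // per-block scale
    int8_t qs[QK8_0_REF];      // quantized values
};

// Symmetric per-block quantization: d = amax/127, q = round(x/d), as in the kernel above.
static void quantize_block_q8_0_ref(const float * x, block_q8_0_ref * out) {
    float amax = 0.0f;
    for (int j = 0; j < QK8_0_REF; ++j) {
        amax = std::fmax(amax, std::fabs(x[j]));
    }
    const float d  = amax / 127.0f;             // (1 << 7) - 1
    const float id = d != 0.0f ? 1.0f/d : 0.0f; // avoid dividing by zero on all-zero blocks

    out->d = d;
    for (int j = 0; j < QK8_0_REF; ++j) {
        out->qs[j] = (int8_t) std::round(x[j] * id);
    }
}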
1.0f/d : 0.0f; + + dsti->dm.x = d; + dsti->dm.y = vmin; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (xi[0 + j] - vmin)*id; + const float x1 = (xi[QK4_1/2 + j] - vmin)*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + +template +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { + const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; + + if (i >= ne) { + return; + } + + const int i02 = i / (ne00*ne01); + const int i01 = (i - i02*ne01*ne00) / ne00; + const int i00 = (i - i02*ne01*ne00 - i01*ne00); + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; + + const int i12 = i / (ne10*ne11); + const int i11 = (i - i12*ne10*ne11) / ne10; + const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + +static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +struct rope_corr_dims { + float v[4]; +}; + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static __device__ void rope_yarn( + float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; +} + +// rope == RoPE == rotary positional embedding +template +static __global__ void rope( + const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, + float ext_factor, float attn_factor, rope_corr_dims corr_dims +) { + const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (col >= ncols) { + return; + } + + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int i = row*ncols + col; + const int i2 = row/p_delta_rows; + + const int p = has_pos ? 
pos[i2] : 0; + const float theta_base = p*powf(freq_base, -float(col)/ncols); + + float cos_theta, sin_theta; + rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + 1]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + 1] = x0*sin_theta + x1*cos_theta; +} + +template +static __global__ void rope_neox( + const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims +) { + const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (col >= ncols) { + return; + } + + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int ib = col / n_dims; + const int ic = col % n_dims; + + const int i = row*ncols + ib*n_dims + ic/2; + const int i2 = row/p_delta_rows; + + float cur_rot = inv_ndims * ic - ib; + + const int p = has_pos ? pos[i2] : 0; + const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); + + float cos_theta, sin_theta; + rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + n_dims/2]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; +} + +static __global__ void rope_glm_f32( + const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, + int n_ctx +) { + const int col = blockDim.x*blockIdx.x + threadIdx.x; + const int half_n_dims = ncols/4; + + if (col >= half_n_dims) { + return; + } + + const int row = blockDim.y*blockIdx.y + threadIdx.y; + const int i = row*ncols + col; + const int i2 = row/p_delta_rows; + + const float col_theta_scale = powf(freq_base, -2.0f*col/ncols); + // FIXME: this is likely wrong + const int p = pos != nullptr ? 
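The rope and rope_neox kernels above rotate element pairs by a position-dependent angle, with rope_yarn blending interpolated and extrapolated angles and rescaling the magnitude. Ignoring the YaRN correction and scale factors, the per-pair update is a plain 2-D rotation, sketched here on the host for the non-NeoX indexing; the names are illustrative.

#include <cmath>

// Rotate one (x0, x1) pair by theta = p * freq_base^(-i/ncols), the non-NeoX RoPE form above
// (freq_scale, ext_factor and attn_factor are left out of this sketch).
static void rope_pair_ref(float x0, float x1, int p, int i, int ncols, float freq_base,
                          float * out0, float * out1) {
    const float theta = p * std::pow(freq_base, -(float)i/ncols);
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    *out0 = x0*c - x1*s;
    *out1 = x0*s + x1*c;
}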
pos[i2] : 0; + + const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; + const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + half_n_dims]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; + + const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; + const float sin_block_theta = sinf(block_theta); + const float cos_block_theta = cosf(block_theta); + + const float x2 = x[i + half_n_dims * 2]; + const float x3 = x[i + half_n_dims * 3]; + + dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta; + dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; +} + +static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, + const int n_heads_log2_floor, const float m0, const float m1) { + const int col = blockDim.x*blockIdx.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int row = blockDim.y*blockIdx.y + threadIdx.y; + const int i = row*ncols + col; + + const int k = row/k_rows; + + float m_k; + if (k < n_heads_log2_floor) { + m_k = powf(m0, k + 1); + } else { + m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + dst[i] = col * m_k + x[i]; +} + +static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) { + const int row = blockIdx.y; + const int col = threadIdx.x; + + float sum = 0.0f; + for (int i = col; i < ncols; i += blockDim.x) { + sum += x[row * ncols + i]; + } + + sum = warp_reduce_sum(sum); + + if (col == 0) { + dst[row] = sum; + } +} + +template +static inline __device__ void swap(T & a, T & b) { + T tmp = a; + a = b; + b = tmp; +} + +template +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.y; + + if (col >= ncols) return; + + const float * x_row = x + row * ncols; + int * dst_row = dst + row * ncols; + + // initialize indices + if (col < ncols) { + dst_row[col] = col; + } + __syncthreads(); + + for (int k = 2; k <= ncols; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + swap(dst_row[col], dst_row[ixj]); + } + } else { + if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + swap(dst_row[col], dst_row[ixj]); + } + } + } + __syncthreads(); + } + } +} + +static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { + const int col = blockDim.y*blockIdx.y + threadIdx.y; + const int row = blockDim.x*blockIdx.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int i = row*ncols + col; + //dst[i] = col > (n_past + row % rows_per_channel) ? 
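k_argsort_f32_i32 above sorts an index row with an in-block bitonic network, so dst_row ends up holding the permutation that orders x_row ascending or descending depending on the template order (bitonic sort, as usual, relies on the column count behaving like a power of two). Its result is easiest to check against a plain host argsort; a reference with illustrative names:

#include <algorithm>
#include <numeric>
#include <vector>

// Host reference for the ascending case: indices that sort `row`.
static std::vector<int> argsort_asc_ref(const std::vector<float> & row) {
    std::vector<int> idx(row.size());
    std::iota(idx.begin(), idx.end(), 0);    // 0, 1, 2, ...
    std::sort(idx.begin(), idx.end(),
              [&](int a, int b) { return row[a] < row[b]; });
    return idx;
}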
-INFINITY : x[i]; + //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU + dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; +} + +static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { + const int tid = threadIdx.x; + const int rowx = blockIdx.x; + const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension + + const int block_size = blockDim.x; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE]; + + float max_val = -INFINITY; + + for (int col = tid; col < ncols; col += block_size) { + const int ix = rowx*ncols + col; + const int iy = rowy*ncols + col; + max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f)); + } + + // find the max value in the block + max_val = warp_reduce_max(max_val); + if (block_size > WARP_SIZE) { + if (warp_id == 0) { + buf[lane_id] = -INFINITY; + } + __syncthreads(); + + if (lane_id == 0) { + buf[warp_id] = max_val; + } + __syncthreads(); + + max_val = buf[lane_id]; + max_val = warp_reduce_max(max_val); + } + + float tmp = 0.f; + + for (int col = tid; col < ncols; col += block_size) { + const int ix = rowx*ncols + col; + const int iy = rowy*ncols + col; + const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val); + tmp += val; + dst[ix] = val; + } + + // find the sum of exps in the block + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + if (warp_id == 0) { + buf[lane_id] = 0.f; + } + __syncthreads(); + + if (lane_id == 0) { + buf[warp_id] = tmp; + } + __syncthreads(); + + tmp = buf[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float inv_tmp = 1.f / tmp; + + for (int col = tid; col < ncols; col += block_size) { + const int i = rowx*ncols + col; + dst[i] *= inv_tmp; + } +} + +static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = scale * x[i]; +} + +static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); +} + +static __global__ void im2col_f32_f16( + const float * x, half * dst, + int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW, + int s0, int s1, int p0, int p1, int d0, int d1) { + const int i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= pelements) { + return; + } + + const int ksize = OW * (KH > 1 ? 
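soft_max_f32 above is the standard numerically stable softmax with an optional additive mask and a scale factor: find the row maximum (a warp reduction plus a small shared buffer when the block is wider than one warp), exponentiate the shifted values, then normalize by their sum. Stripped of the block-level reductions, the math looks like this host sketch with illustrative names:

#include <cmath>
#include <vector>

// Reference for one row: dst = softmax(scale*x + mask), computed stably.
// Assumes mask has the same length as x.
static std::vector<float> soft_max_row_ref(const std::vector<float> & x,
                                           const std::vector<float> & mask,
                                           float scale) {
    float max_val = -INFINITY;
    for (size_t i = 0; i < x.size(); ++i) {
        max_val = std::max(max_val, scale*x[i] + mask[i]);
    }

    std::vector<float> dst(x.size());
    float sum = 0.0f;
    for (size_t i = 0; i < x.size(); ++i) {
        dst[i] = std::exp((scale*x[i] + mask[i]) - max_val); // shift by the max for stability
        sum += dst[i];
    }
    for (float & v : dst) {
        v /= sum;
    }
    return dst;
}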
KW : 1); + const int kx = i / ksize; + const int kd = kx * ksize; + const int ky = (i - kd) / OW; + const int ix = i % OW; + + const int iiw = ix * s0 + kx * d0 - p0; + const int iih = blockIdx.y * s1 + ky * d1 - p1; + + const int offset_dst = + (blockIdx.y * OW + ix) * CHW + + (blockIdx.z * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = __float2half(0.0f); + } else { + const int offset_src = blockIdx.z * offset_delta; + dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]); + } +} + +template +static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); + const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); + const dim3 block_nums(block_num_x, ne10, ne11*ne12); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + GGML_ASSERT(ne00 % 2 == 0); + + k_get_rows<<>>( + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); + + (void) dst; +} + +template +static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); + const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; + const dim3 block_nums(block_num_x, ne10, ne11*ne12); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + k_get_rows_float<<>>( + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); + + (void) dst; +} + +template +struct bin_bcast_cuda { + template + void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, + const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, + cudaStream_t stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + int nr0 = ne10/ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne0[] = {ne0, ne1, ne2, ne3}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb0[] = {nb0, nb1, nb2, nb3}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + 
auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne0); + collapse(cne1); + } + } + { + int64_t ne0 = cne0[0]; + int64_t ne1 = cne0[1]; + int64_t ne2 = cne0[2]; + int64_t ne3 = cne0[3]; + + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; + + size_t nb0 = cnb0[0]; + size_t nb1 = cnb0[1]; + size_t nb2 = cnb0[2]; + size_t nb3 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + dim3 block_dims; + block_dims.x = std::min(hne0, block_size); + block_dims.y = std::min(ne1, block_size / block_dims.x); + block_dims.z = std::min(std::min(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U); + + dim3 block_nums( + (hne0 + block_dims.x - 1) / block_dims.x, + (ne1 + block_dims.y - 1) / block_dims.y, + (ne2*ne3 + block_dims.z - 1) / block_dims.z + ); + + if (block_nums.z > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + k_bin_bcast_unravel<<>>( + src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s10, */ s11, s12, s13); + } else { + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s10, */ s11, s12, s13); + } + } + } +}; + +static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements, + const int ne10, const int ne11, const int ne12, + const int nb1, const int nb2, const int offset, cudaStream_t stream) { + int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; + acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset); +} + +static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; + gelu_f32<<>>(x, dst, k); +} + +static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_f32<<>>(x, dst, k); +} + +static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; + gelu_quick_f32<<>>(x, dst, k); +} + +static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; + tanh_f32<<>>(x, dst, k); +} + +static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + relu_f32<<>>(x, dst, k); +} + +static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, 
cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + leaky_relu_f32<<>>(x, dst, k, negative_slope); +} + +static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE; + sqr_f32<<>>(x, dst, k); +} + +static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + norm_f32<<>>(x, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + norm_f32<1024><<>>(x, dst, ncols, eps); + } +} + +static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) { + static const float eps = 1e-6f; + if (group_size < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + group_norm_f32<<>>(x, dst, group_size, ne_elements, eps); + } else { + const dim3 block_dims(1024, 1, 1); + group_norm_f32<1024><<>>(x, dst, group_size, ne_elements, eps); + } +} + +static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; + dim3 gridDim(num_blocks, ne1, ne2); + concat_f32<<>>(x, y, dst, ne0, ne02); +} + +static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) { + int ne0 = (ne00 * scale_factor); + int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02); + upscale_f32<<>>(x, dst, ne00, ne00 * ne01, scale_factor); +} + +static void pad_f32_cuda(const float * x, float * dst, + const int ne00, const int ne01, const int ne02, + const int ne0, const int ne1, const int ne2, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; + dim3 gridDim(num_blocks, ne1, ne2); + pad_f32<<>>(x, dst, ne0, ne00, ne01, ne02); +} + +static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_f32<<>>(x, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024><<>>(x, dst, ncols, eps); + } +} + +static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) { + const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, ky, 1); + const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); + quantize_q8_1<<>>(x, vy, kx, kx_padded); +} + +template +static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +template +static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; +#if QK_K == 256 + dequantize_block_q2_K<<>>(vx, y); +#else + dequantize_block_q2_K<<>>(vx, y); +#endif +} + +template +static void dequantize_row_q3_K_cuda(const void * vx, dst_t 
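The host wrappers above (gelu_f32_cuda, silu_f32_cuda, relu_f32_cuda and the rest) all size their launches the same way: ceil(k / BLOCK_SIZE) blocks of BLOCK_SIZE threads, with the kernel discarding tail threads through an early i >= k return. A minimal sketch of that pattern; DEMO_BLOCK_SIZE and the kernel name are illustrative, not one of the CUDA_*_BLOCK_SIZE macros.

#include <cuda_runtime.h>

#define DEMO_BLOCK_SIZE 256

// Element-wise kernel: tail threads past k simply return, as in relu_f32 and friends.
__global__ void scale_demo(const float * x, float * dst, float s, int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = s * x[i];
}

// Host wrapper: ceiling division so the last, partially filled block is still launched.
static void scale_demo_cuda(const float * x, float * dst, float s, int k, cudaStream_t stream) {
    const int num_blocks = (k + DEMO_BLOCK_SIZE - 1) / DEMO_BLOCK_SIZE;
    scale_demo<<<num_blocks, DEMO_BLOCK_SIZE, 0, stream>>>(x, dst, s, k);
}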
* y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; +#if QK_K == 256 + dequantize_block_q3_K<<>>(vx, y); +#else + dequantize_block_q3_K<<>>(vx, y); +#endif +} + +template +static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q4_K<<>>(vx, y); +} + +template +static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; +#if QK_K == 256 + dequantize_block_q5_K<<>>(vx, y); +#else + dequantize_block_q5_K<<>>(vx, y); +#endif +} + +template +static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; +#if QK_K == 256 + dequantize_block_q6_K<<>>(vx, y); +#else + dequantize_block_q6_K<<>>(vx, y); +#endif +} + +static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_F32: + return dequantize_block_cuda<1, 1, convert_f32>; + default: + return nullptr; + } +} + +static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_F16: + return dequantize_block_cuda<1, 1, convert_f16>; + default: + return nullptr; + } +} + +static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int 
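ggml_get_to_fp16_cuda and ggml_get_to_fp32_cuda above are simple dispatch tables: a switch over the tensor's quantization type that returns a dequantize-launcher function pointer, or nullptr when no converter exists. The shape of that pattern, reduced to an illustrative enum and stub launchers (nothing here is part of the real API):

// Illustrative stand-ins; the real code dispatches over ggml_type to CUDA launchers.
enum class demo_type { q4_0, q8_0, other };

typedef void (*to_fp32_fn)(const void * src, float * dst, int k);

static void dequant_q4_0_stub(const void *, float *, int) { /* would launch a kernel */ }
static void dequant_q8_0_stub(const void *, float *, int) { /* would launch a kernel */ }

// Returns nullptr for types with no converter, mirroring the default case above.
static to_fp32_fn get_to_fp32_ref(demo_type t) {
    switch (t) {
        case demo_type::q4_0: return dequant_q4_0_stub;
        case demo_type::q8_0: return dequant_q8_0_stub;
        default:              return nullptr;
    }
}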
block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); +} + +static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec<1, 1, convert_f16> + <<>>(vx, y, dst, ncols, 
nrows); +} + +static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK4_1 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK5_0 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK5_1 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const 
dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void ggml_mul_mat_q4_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q4_0_RDNA2; + mmq_y = MMQ_Y_Q4_0_RDNA2; + nwarps = NWARPS_Q4_0_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q4_0_RDNA1; + mmq_y = MMQ_Y_Q4_0_RDNA1; + nwarps = NWARPS_Q4_0_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q4_0_AMPERE; + mmq_y = MMQ_Y_Q4_0_AMPERE; + nwarps = NWARPS_Q4_0_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q4_0_PASCAL; + mmq_y = MMQ_Y_Q4_0_PASCAL; + nwarps = NWARPS_Q4_0_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q4_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q4_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q4_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q4_1_RDNA2; + mmq_y = MMQ_Y_Q4_1_RDNA2; + nwarps = NWARPS_Q4_1_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q4_1_RDNA1; + mmq_y = MMQ_Y_Q4_1_RDNA1; + nwarps = NWARPS_Q4_1_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q4_1_AMPERE; + mmq_y = MMQ_Y_Q4_1_AMPERE; + nwarps = NWARPS_Q4_1_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q4_1_PASCAL; + mmq_y = MMQ_Y_Q4_1_PASCAL; + nwarps = NWARPS_Q4_1_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q4_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q4_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q5_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int 
ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q5_0_RDNA2; + mmq_y = MMQ_Y_Q5_0_RDNA2; + nwarps = NWARPS_Q5_0_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q5_0_RDNA1; + mmq_y = MMQ_Y_Q5_0_RDNA1; + nwarps = NWARPS_Q5_0_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q5_0_AMPERE; + mmq_y = MMQ_Y_Q5_0_AMPERE; + nwarps = NWARPS_Q5_0_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q5_0_PASCAL; + mmq_y = MMQ_Y_Q5_0_PASCAL; + nwarps = NWARPS_Q5_0_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q5_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q5_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q5_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q5_1_RDNA2; + mmq_y = MMQ_Y_Q5_1_RDNA2; + nwarps = NWARPS_Q5_1_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q5_1_RDNA1; + mmq_y = MMQ_Y_Q5_1_RDNA1; + nwarps = NWARPS_Q5_1_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q5_1_AMPERE; + mmq_y = MMQ_Y_Q5_1_AMPERE; + nwarps = NWARPS_Q5_1_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q5_1_PASCAL; + mmq_y = MMQ_Y_Q5_1_PASCAL; + nwarps = NWARPS_Q5_1_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q5_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q5_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q8_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q8_0_RDNA2; + mmq_y = MMQ_Y_Q8_0_RDNA2; + nwarps = NWARPS_Q8_0_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q8_0_RDNA1; + mmq_y = MMQ_Y_Q8_0_RDNA1; + nwarps = NWARPS_Q8_0_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q8_0_AMPERE; + mmq_y = MMQ_Y_Q8_0_AMPERE; + nwarps = NWARPS_Q8_0_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q8_0_PASCAL; + mmq_y = 
MMQ_Y_Q8_0_PASCAL; + nwarps = NWARPS_Q8_0_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q8_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q8_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q2_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q2_K_RDNA2; + mmq_y = MMQ_Y_Q2_K_RDNA2; + nwarps = NWARPS_Q2_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q2_K_RDNA1; + mmq_y = MMQ_Y_Q2_K_RDNA1; + nwarps = NWARPS_Q2_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q2_K_AMPERE; + mmq_y = MMQ_Y_Q2_K_AMPERE; + nwarps = NWARPS_Q2_K_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q2_K_PASCAL; + mmq_y = MMQ_Y_Q2_K_PASCAL; + nwarps = NWARPS_Q2_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q2_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q2_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q3_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + +#if QK_K == 256 + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q3_K_RDNA2; + mmq_y = MMQ_Y_Q3_K_RDNA2; + nwarps = NWARPS_Q3_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q3_K_RDNA1; + mmq_y = MMQ_Y_Q3_K_RDNA1; + nwarps = NWARPS_Q3_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q3_K_AMPERE; + mmq_y = MMQ_Y_Q3_K_AMPERE; + nwarps = NWARPS_Q3_K_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q3_K_PASCAL; + mmq_y = MMQ_Y_Q3_K_PASCAL; + nwarps = NWARPS_Q3_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q3_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q3_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +#endif +} + +static void ggml_mul_mat_q4_K_q8_1_cuda( + const void * vx, const void * vy, float * 
dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q4_K_RDNA2; + mmq_y = MMQ_Y_Q4_K_RDNA2; + nwarps = NWARPS_Q4_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q4_K_RDNA1; + mmq_y = MMQ_Y_Q4_K_RDNA1; + nwarps = NWARPS_Q4_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q4_K_AMPERE; + mmq_y = MMQ_Y_Q4_K_AMPERE; + nwarps = NWARPS_Q4_K_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q4_K_PASCAL; + mmq_y = MMQ_Y_Q4_K_PASCAL; + nwarps = NWARPS_Q4_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q4_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q4_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q5_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q5_K_RDNA2; + mmq_y = MMQ_Y_Q5_K_RDNA2; + nwarps = NWARPS_Q5_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q5_K_RDNA1; + mmq_y = MMQ_Y_Q5_K_RDNA1; + nwarps = NWARPS_Q5_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q5_K_AMPERE; + mmq_y = MMQ_Y_Q5_K_AMPERE; + nwarps = NWARPS_Q5_K_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q5_K_PASCAL; + mmq_y = MMQ_Y_Q5_K_PASCAL; + nwarps = NWARPS_Q5_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q5_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q5_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q6_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q6_K_RDNA2; + mmq_y = MMQ_Y_Q6_K_RDNA2; + nwarps = NWARPS_Q6_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q6_K_RDNA1; + mmq_y = MMQ_Y_Q6_K_RDNA1; + nwarps = NWARPS_Q6_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q6_K_AMPERE; + mmq_y = MMQ_Y_Q6_K_AMPERE; + nwarps = NWARPS_Q6_K_AMPERE; + } else if (compute_capability >= 
MIN_CC_DP4A) { + mmq_x = MMQ_X_Q6_K_PASCAL; + mmq_y = MMQ_Y_Q6_K_PASCAL; + nwarps = NWARPS_Q6_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q6_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q6_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_p021_f16_f32_cuda( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, + const int nchannels_x, const int nchannels_y, cudaStream_t stream) { + + const dim3 block_nums(1, nrows_x, nchannels_y); + const dim3 block_dims(WARP_SIZE, 1, 1); + mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); +} + +static void ggml_mul_mat_vec_nc_f16_f32_cuda( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, + const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { + + const dim3 block_nums(1, nrows_x, nchannels_y); + const dim3 block_dims(WARP_SIZE, 1, 1); + mul_mat_vec_nc_f16_f32<<>> + (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); +} + +static void ggml_cpy_f32_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_f32_f16<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_f16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_f32_f16<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_q8_0_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + GGML_ASSERT(ne % QK8_0 == 0); + const int num_blocks = ne / QK8_0; + cpy_f32_q<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_q4_0_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + GGML_ASSERT(ne % QK4_0 == 0); + const int num_blocks = ne / QK4_0; + cpy_f32_q<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_q4_1_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { 
+
+    GGML_ASSERT(ne % QK4_1 == 0);
+    const int num_blocks = ne / QK4_1;
+    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+}
+
+static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
+    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+}
+
+template<typename T>
+static void rope_cuda(
+    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
+    GGML_ASSERT(ncols % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
+    if (pos == nullptr) {
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
+    } else {
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
+    }
+}
+
+template<typename T>
+static void rope_neox_cuda(
+    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
+    GGML_ASSERT(ncols % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.0f / n_dims;
+
+    if (pos == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
+        );
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, inv_ndims
+        );
+    }
+}
+
+static void rope_glm_f32_cuda(
+    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, int n_ctx, cudaStream_t stream
+) {
+    GGML_ASSERT(ncols % 4 == 0);
+    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
+}
+
+static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
+                           const int k_rows, const int n_heads_log2_floor, const float m0,
+                           const float m1, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
+}
+
+static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+    // bitonic sort requires ncols to be power of 2
+    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+
+    const dim3 block_dims(ncols, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    if (order == GGML_SORT_ASC) {
+        k_argsort_f32_i32<GGML_SORT_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else if (order == GGML_SORT_DESC) {
+        k_argsort_f32_i32<GGML_SORT_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        GGML_ASSERT(false);
+    }
+}
+
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
+    const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
+    const dim3 block_nums(nrows_x, block_num_x, 1);
+    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
+}
+
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+}
+
+static void im2col_f32_f16_cuda(const float* x, half* dst,
+    int IW, int IH, int OW, int OH, int KW, int KH, int IC,
+    int offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+    const int parallel_elements = OW * KW * KH;
+    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    dim3 block_nums(num_blocks, OH, IC);
+    im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
+// buffer pool for cuda
+#define MAX_CUDA_BUFFERS 256
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
+struct cuda_buffer {
+    void * ptr = nullptr;
+    size_t size = 0;
+};
+
+static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+#ifdef DEBUG_CUDA_MALLOC
+    int nnz = 0;
+    size_t max_size = 0, tot_size = 0;
+#endif
+    size_t best_diff = 1ull << 36;
+    int ibest = -1;
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[id][i];
+        if (b.ptr != nullptr) {
+#ifdef DEBUG_CUDA_MALLOC
+            ++nnz;
+            tot_size += b.size;
+            if (b.size > max_size) max_size = b.size;
+#endif
+            if (b.size >= size) {
+                size_t diff = b.size - size;
+                if
(diff < best_diff) { + best_diff = diff; + ibest = i; + if (!best_diff) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + } + } + } + if (ibest >= 0) { + cuda_buffer& b = g_cuda_buffer_pool[id][ibest]; + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } +#ifdef DEBUG_CUDA_MALLOC + fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz, + (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024)); +#endif + void * ptr; + size_t look_ahead_size = (size_t) (1.05 * size); + look_ahead_size = 256 * ((look_ahead_size + 255)/256); + CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); + *actual_size = look_ahead_size; + return ptr; +} + +static void ggml_cuda_pool_free(void * ptr, size_t size) { + scoped_spin_lock lock(g_cuda_pool_lock); + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { + cuda_buffer& b = g_cuda_buffer_pool[id][i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); + CUDA_CHECK(cudaFree(ptr)); +} + +static bool g_cublas_loaded = false; + +bool ggml_cublas_loaded(void) { + return g_cublas_loaded; +} + +void ggml_init_cublas() { + static bool initialized = false; + + if (!initialized) { + +#ifdef __HIP_PLATFORM_AMD__ + // Workaround for a rocBLAS bug when using multiple graphics cards: + // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346 + rocblas_initialize(); + CUDA_CHECK(cudaDeviceSynchronize()); +#endif + + if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) { + initialized = true; + g_cublas_loaded = false; + return; + } + + GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); + int64_t total_vram = 0; +#if defined(GGML_CUDA_FORCE_MMQ) + fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); +#else + fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); +#endif +#if defined(CUDA_USE_TENSOR_CORES) + fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__); +#else + fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__); +#endif + fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count); + for (int id = 0; id < g_device_count; ++id) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); + fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor); + + g_tensor_split[id] = total_vram; + total_vram += prop.totalGlobalMem; +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD; +#else + g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + } + for (int id = 0; id < g_device_count; ++id) { + g_tensor_split[id] /= total_vram; + } + + for (int id = 0; id < g_device_count; ++id) { + CUDA_CHECK(ggml_cuda_set_device(id)); + + // create cuda streams + for (int is = 0; is < MAX_STREAMS; ++is) { + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking)); + } + + // create cublas handle + CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); + CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); + } + + // configure logging to stdout + // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + + initialized = 
true; + g_cublas_loaded = true; + } +} + +void ggml_cuda_set_tensor_split(const float * tensor_split) { + if (tensor_split == nullptr) { + return; + } + bool all_zero = true; + for (int i = 0; i < g_device_count; ++i) { + if (tensor_split[i] != 0.0f) { + all_zero = false; + break; + } + } + if (all_zero) { + return; + } + float split_sum = 0.0f; + for (int i = 0; i < g_device_count; ++i) { + g_tensor_split[i] = split_sum; + split_sum += tensor_split[i]; + } + for (int i = 0; i < g_device_count; ++i) { + g_tensor_split[i] /= split_sum; + } +} + +void * ggml_cuda_host_malloc(size_t size) { + if (getenv("GGML_CUDA_NO_PINNED") != nullptr) { + return nullptr; + } + + void * ptr = nullptr; + cudaError_t err = cudaMallocHost((void **) &ptr, size); + if (err != cudaSuccess) { + // The allocation error can be bypassed. A null ptr will assigned out of this function. + // This can fixed the OOM error in WSL. + cudaGetLastError(); + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", + size/1024.0/1024.0, cudaGetErrorString(err)); + return nullptr; + } + + return ptr; +} + +void ggml_cuda_host_free(void * ptr) { + CUDA_CHECK(cudaFreeHost(ptr)); +} + +static cudaError_t ggml_cuda_cpy_tensor_2d( + void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { + + cudaMemcpyKind kind; + char * src_ptr; + if (src->backend == GGML_BACKEND_CPU) { + kind = cudaMemcpyHostToDevice; + src_ptr = (char *) src->data; + } else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) { + GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1])); + kind = cudaMemcpyDeviceToDevice; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; + int id; + CUDA_CHECK(cudaGetDevice(&id)); + src_ptr = (char *) extra->data_device[id]; + } else { + GGML_ASSERT(false); + } + char * dst_ptr = (char *) dst; + + const int64_t ne0 = src->ne[0]; + const int64_t nb0 = src->nb[0]; + const int64_t nb1 = src->nb[1]; + const int64_t nb2 = src->nb[2]; + const int64_t nb3 = src->nb[3]; + const enum ggml_type type = src->type; + const int64_t ts = ggml_type_size(type); + const int64_t bs = ggml_blck_size(type); + int64_t i1_diff = i1_high - i1_low; + + const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; + if (nb0 == ts && nb1 == ts*ne0/bs) { + return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream); + } else if (nb0 == ts) { + return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream); + } else { + for (int64_t i1 = 0; i1 < i1_diff; i1++) { + const void * rx = (const void *) ((const char *) x + i1*nb1); + void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + // pretend the row is a matrix with cols=1 + cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream); + if (r != cudaSuccess) return r; + } + return cudaSuccess; + } +} + +static void ggml_cuda_op_get_rows( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); + + const int32_t * src1_i32 = (const int32_t *) src1_d; + + switch (src0->type) { + case GGML_TYPE_F16: + get_rows_cuda_float(src0, src1, dst, 
(const half *)src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_F32: + get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q4_0: + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q4_1: + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q5_0: + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q5_1: + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q8_0: + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + default: + // TODO: k-quants + GGML_ASSERT(false); + break; + } +} + +template +inline void ggml_cuda_op_bin_bcast( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()(src0, src1, dst, (const half *) src0_dd, src1_dd, (half *) dst_dd, main_stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const half *) src0_dd, src1_dd, dst_dd, main_stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ASSERT(false); + } +} + +static void ggml_cuda_op_repeat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) { + + ggml_cuda_op_bin_bcast>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream); + + (void) src1; + (void) src1_d; +} + +inline void ggml_cuda_op_add( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + +inline void ggml_cuda_op_acc( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream); + + (void) dst; +} + +inline void ggml_cuda_op_mul( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + +inline void ggml_cuda_op_div( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & 
main_stream) { + + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + +inline void ggml_cuda_op_gelu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + gelu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_silu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + silu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_gelu_quick( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_tanh( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + tanh_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_relu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_leaky_relu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + + leaky_relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_sqr( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_norm( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float 
eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + + +inline void ggml_cuda_op_group_norm( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int num_groups = dst->op_params[0]; + int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_concat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream); + } + + (void) src1; + (void) dst; +} + +inline void ggml_cuda_op_upscale( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + + const int scale_factor = dst->op_params[0]; + + upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream); + + (void) src1; + (void) dst; +} + +inline void ggml_cuda_op_pad( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + + pad_f32_cuda(src0_dd, dst_dd, + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], main_stream); + + (void) src1; + (void) dst; +} + +inline void ggml_cuda_op_rms_norm( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + rms_norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_mul_mat_q( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + const int64_t 
row_diff = row_high - row_low; + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into + const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; + + switch (src0->type) { + case GGML_TYPE_Q4_0: + ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q4_1: + ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q5_0: + ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q5_1: + ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q8_0: + ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q2_K: + ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q3_K: + ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q4_K: + ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q5_K: + ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q6_K: + ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + default: + GGML_ASSERT(false); + break; + } + + (void) src1; + (void) dst; + (void) src1_ddf_i; +} + +static int64_t get_row_rounding(ggml_type type) { + int64_t min_compute_capability = INT_MAX; + int64_t max_compute_capability = INT_MIN; + for (int64_t id = 0; id < g_device_count; ++id) { + if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { + if (min_compute_capability > g_compute_capabilities[id]) { + min_compute_capability = g_compute_capabilities[id]; + } + if (max_compute_capability < g_compute_capabilities[id]) { + max_compute_capability = g_compute_capabilities[id]; + } + } + } + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + switch(type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return max_compute_capability >= CC_RDNA2 ? 128 : 64; + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return 1; + case GGML_TYPE_Q2_K: + return max_compute_capability >= CC_RDNA2 ? 128 : 32; + case GGML_TYPE_Q3_K: + return min_compute_capability < CC_RDNA2 ? 128 : 64; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return max_compute_capability >= CC_RDNA2 ? 128 : 64; + default: + GGML_ASSERT(false); + } +#else + switch(type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return max_compute_capability >= CC_VOLTA ? 
128 : 64; + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return 64; + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return 1; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return max_compute_capability >= CC_VOLTA ? 128 : 64; + case GGML_TYPE_Q6_K: + return 64; + default: + GGML_ASSERT(false); + } +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +} + +inline void ggml_cuda_op_mul_mat_vec_q( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { + + GGML_ASSERT(ggml_nrows(src1) == 1); + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + switch (src0->type) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + default: + GGML_ASSERT(false); + break; + } + + (void) src1; + (void) dst; + (void) src1_ddf_i; + (void) src1_ncols; + (void) src1_padded_row_size; +} + +inline void ggml_cuda_op_dequantize_mul_mat_vec( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics +#ifdef GGML_CUDA_F16 + size_t ash; + dfloat * src1_dfloat = nullptr; // dfloat == half + + bool src1_convert_f16 = + src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || + src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + + if (src1_convert_f16) { + src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash); + ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, + ne00, 1, sizeof(float), 0, 0, + ne00, 1, sizeof(half), 0, 0, stream); + } +#else + const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion +#endif // GGML_CUDA_F16 + + switch (src0->type) { + 
case GGML_TYPE_Q4_0: + dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_1: + dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_F16: + convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + default: + GGML_ASSERT(false); + break; + } + +#ifdef GGML_CUDA_F16 + if (src1_convert_f16) { + ggml_cuda_pool_free(src1_dfloat, ash); + } +#endif // GGML_CUDA_F16 + + (void) src1; + (void) dst; + (void) src1_ddq_i; + (void) src1_ncols; + (void) src1_padded_row_size; +} + +inline void ggml_cuda_op_mul_mat_cublas( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { + + GGML_ASSERT(src0_dd_i != nullptr); + GGML_ASSERT(src1_ddf_i != nullptr); + GGML_ASSERT(dst_dd_i != nullptr); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + + const int64_t row_diff = row_high - row_low; + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + // the main device has a larger memory buffer to hold the results from all GPUs + // ldc == nrows of the matrix that cuBLAS writes into + int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; + + const int compute_capability = g_compute_capabilities[id]; + + if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { + // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 + half * src0_as_f16 = nullptr; + size_t src0_as = 0; + if (src0->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + size_t ne = row_diff*ne00; + src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); + to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); + } + const half * src0_ptr = src0->type == GGML_TYPE_F16 ? 
(const half *) src0_dd_i : src0_as_f16; + + half * src1_as_f16 = nullptr; + size_t src1_as = 0; + if (src1->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + size_t ne = src1_ncols*ne10; + src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); + to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); + } + const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16; + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); + CUBLAS_CHECK( + cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16, CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); + + ggml_cuda_pool_free(dst_f16, dst_as); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_as_f16, src0_as); + } + + if (src1_as != 0) { + ggml_cuda_pool_free(src1_as_f16, src1_as); + } + } + else { + float * src0_ddq_as_f32 = nullptr; + size_t src0_as = 0; + + if (src0->type != GGML_TYPE_F32) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); + GGML_ASSERT(to_fp32_cuda != nullptr); + src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT + to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); + } + const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? 
(const float *) src0_dd_i : src0_ddq_as_f32; + + const float alpha = 1.0f; + const float beta = 0.0f; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); + CUBLAS_CHECK( + cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ddf_i, ne00, + src1_ddf_i, ne10, + &beta, dst_dd_i, ldc)); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); + } + } + + (void) dst; + (void) src1_ddq_i; + (void) src1_padded_row_size; +} + +inline void ggml_cuda_op_rope( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t nrows = ggml_nrows(src0); + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + + // RoPE alteration for extended context + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + const int32_t * pos = nullptr; + if ((mode & 1) == 0) { + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src1->ne[0] == ne2); + pos = (const int32_t *) src1_dd; + } + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + rope_corr_dims corr_dims; + ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); + + // compute + if (is_glm) { + GGML_ASSERT(false); + rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream); + } else if (is_neox) { + if (src0->type == GGML_TYPE_F32) { + rope_neox_cuda( + (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, main_stream + ); + } else if (src0->type == GGML_TYPE_F16) { + rope_neox_cuda( + (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, main_stream + ); + } else { + GGML_ASSERT(false); + } + } else { + if (src0->type == GGML_TYPE_F32) { + rope_cuda( + (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, main_stream + ); + } else if (src0->type == GGML_TYPE_F16) { + rope_cuda( + (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, main_stream + ); + } else { + GGML_ASSERT(false); + } + } + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_alibi( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const 
cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t nrows = ggml_nrows(src0); + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); + + //GGML_ASSERT(ne01 + n_past == ne00); + GGML_ASSERT(n_head == ne02); + + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + alibi_f32_cuda(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream); + + (void) src1; + (void) src1_dd; +} + +inline void ggml_cuda_op_im2col( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + + const int64_t IC = src1->ne[is_2D ? 2 : 1]; + const int64_t IH = is_2D ? src1->ne[1] : 1; + const int64_t IW = src1->ne[0]; + + const int64_t KH = is_2D ? src0->ne[1] : 1; + const int64_t KW = src0->ne[0]; + + const int64_t OH = is_2D ? dst->ne[2] : 1; + const int64_t OW = dst->ne[1]; + + const size_t delta_offset = src1->nb[is_2D ? 
2 : 1] / 4; // nb is byte offset, src is type float32 + + im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + + (void) src0; + (void) src0_dd; +} + + +inline void ggml_cuda_op_sum_rows( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + sum_rows_f32_cuda(src0_dd, dst_dd, ncols, nrows, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_argsort( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + argsort_f32_i32_cuda(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_diag_mask_inf( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int nrows0 = ggml_nrows(src0); + + const int n_past = ((int32_t *) dst->op_params)[0]; + + diag_mask_inf_f32_cuda(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_soft_max( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; + + float scale = 1.0f; + memcpy(&scale, dst->op_params, sizeof(float)); + + soft_max_f32_cuda(src0_dd, src1 ? 
src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); + + (void) dst; +} + +inline void ggml_cuda_op_scale( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + float scale; + // HACK: support for ggml backend interface + if (src1->backend == GGML_BACKEND_CPU) { + scale = ((float *) src1->data)[0]; + } else { + // TODO: pass pointer to kernel instead of copying to host + CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost)); + } + + scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_clamp( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + float min; + float max; + memcpy(&min, dst->op_params, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); + + clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { + const int64_t nrows0 = ggml_nrows(src0); + + const bool use_src1 = src1 != nullptr; + const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; + + GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT); + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * src1_extra = use_src1 ? 
(ggml_tensor_extra_gpu *) src1->extra : nullptr; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + + const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU; + const bool dst_on_device = dst->backend == GGML_BACKEND_GPU; + + const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE; + + // dd = data device + float * src0_ddf = nullptr; + float * src1_ddf = nullptr; + float * dst_ddf = nullptr; + + // as = actual size + size_t src0_asf = 0; + size_t src1_asf = 0; + size_t dst_asf = 0; + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + if (src0_on_device) { + src0_ddf = (float *) src0_extra->data_device[g_main_device]; + } else { + src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream)); + } + + if (use_src1 && !src1_stays_on_host) { + if (src1_on_device) { + src1_ddf = (float *) src1_extra->data_device[g_main_device]; + } else { + src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream)); + } + } + if (dst_on_device) { + dst_ddf = (float *) dst_extra->data_device[g_main_device]; + } else { + dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf); + } + + // do the computation + op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); + CUDA_CHECK(cudaGetLastError()); + + // copy dst to host if necessary + if (!dst_on_device) { + CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream)); + } + + if (src0_asf > 0) { + ggml_cuda_pool_free(src0_ddf, src0_asf); + } + if (src1_asf > 0) { + ggml_cuda_pool_free(src1_ddf, src1_asf); + } + if (dst_asf > 0) { + ggml_cuda_pool_free(dst_ddf, dst_asf); + } + + if (dst->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(cudaDeviceSynchronize()); + } +} + +static void ggml_cuda_set_peer_access(const int n_tokens) { + static bool peer_access_enabled = false; + + const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; + + if (peer_access_enabled == enable_peer_access) { + return; + } + +#ifdef NDEBUG + for (int id = 0; id < g_device_count; ++id) { + CUDA_CHECK(ggml_cuda_set_device(id)); + + for (int id_other = 0; id_other < g_device_count; ++id_other) { + if (id == id_other) { + continue; + } + if (id != g_main_device && id_other != g_main_device) { + continue; + } + + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + if (enable_peer_access) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + } else { + CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other)); + } + } + } + } +#endif // NDEBUG + + peer_access_enabled = enable_peer_access; +} + +static void ggml_cuda_op_mul_mat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, + const bool convert_src1_to_q8_1) { + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + const int64_t nrows0 = ggml_nrows(src0); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + const int64_t nrows1 = ggml_nrows(src1); + + 
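// note on the multi-GPU path: the rows of src0 are split across the devices according to g_tensor_split
+    // (see row_low/row_high below) and, when more than one device is used, the columns of src1 are processed
+    // in chunks of MUL_MAT_SRC1_COL_STRIDE; the partial results are copied back to the main device (or to
+    // the host) at the end of each chunk.
+    // illustrative example (hypothetical values): g_tensor_split = {0.0f, 0.5f} and ne01 = 4096 give
+    // device 0 the rows [0, 2048) and device 1 the rows [2048, 4096), with each boundary rounded down to a
+    // multiple of get_row_rounding(src0->type).
+
+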
GGML_ASSERT(ne03 == ne13); + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + ggml_cuda_set_peer_access(ne11); + + GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT); + + GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); + + const int64_t i02_divisor = ne12 / ne02; + + const size_t src0_ts = ggml_type_size(src0->type); + const size_t src0_bs = ggml_blck_size(src0->type); + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + + const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool src0_is_contiguous = ggml_is_contiguous(src0); + const bool src1_is_contiguous = ggml_is_contiguous(src1); + + const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + GGML_ASSERT(!(split && ne02 > 1)); + GGML_ASSERT(!(split && ne03 > 1)); + GGML_ASSERT(!(split && ne02 < ne12)); + + // dd = data device + char * src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; + float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float + char * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1 + float * dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; + + // as = actual size + size_t src0_as[GGML_CUDA_MAX_DEVICES] = {0}; + size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; + size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0}; + size_t dst_as[GGML_CUDA_MAX_DEVICES] = {0}; + + int64_t row_low[GGML_CUDA_MAX_DEVICES]; + int64_t row_high[GGML_CUDA_MAX_DEVICES]; + + int used_devices = 0; + + for (int64_t id = 0; id < g_device_count; ++id) { + // by default, use all rows + row_low[id] = 0; + row_high[id] = ne01; + + // for multi GPU, get the row boundaries from tensor split + // and round to mul_mat_q tile sizes + if (split) { + const int64_t rounding = get_row_rounding(src0->type); + + if (id != 0) { + row_low[id] = ne01*g_tensor_split[id]; + row_low[id] -= row_low[id] % rounding; + } + + if (id != g_device_count - 1) { + row_high[id] = ne01*g_tensor_split[id + 1]; + row_high[id] -= row_high[id] % rounding; + } + } + } + + for (int64_t id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + continue; + } + + used_devices++; + + const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; + const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; + + ggml_cuda_set_device(id); + const cudaStream_t stream = g_cudaStreams[id][0]; + + if (src0_on_device && src0_is_contiguous) { + src0_dd[id] = (char *) src0_extra->data_device[id]; + } else { + // const size_t size_src0_ddq = split ? 
(row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); + src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); + } + + if (src1_on_device && src1_is_contiguous) { + src1_ddf[id] = (float *) src1_extra->data_device[id]; + } else { + src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); + } + + if (convert_src1_to_q8_1) { + src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); + + if (src1_on_device && src1_is_contiguous) { + quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + } + } + + if (dst_on_device) { + dst_dd[id] = (float *) dst_extra->data_device[id]; + } else { + const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); + dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); + } + } + + // if multiple devices are used they need to wait for the main device + // here an event is recorded that signals that the main device has finished calculating the input data + if (split && used_devices > 1) { + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0])); + } + + const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; + for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { + const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0; + const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; + + for (int64_t id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + continue; + } + + const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; + const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; + const int64_t row_diff = row_high[id] - row_low[id]; + + ggml_cuda_set_device(id); + const cudaStream_t stream = g_cudaStreams[id][is]; + + // wait for main GPU data if necessary + if (split && (id != g_main_device || is != 0)) { + CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0)); + } + + for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) { + const int64_t i03 = i0 / ne12; + const int64_t i02 = i0 % ne12; + + const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; + + // for split tensors the data begins at i0 == i0_offset_low + char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; + float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10; + char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset; + float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? 
ne0 : row_diff); + + // the main device memory buffer can be on VRAM scratch, with space for all partial results + // in that case an offset on dst_ddf_i is needed + if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) { + dst_dd_i += row_low[id]; // offset is 0 if no tensor split + } + + // copy src0, src1 to device if necessary + if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { + if (id != g_main_device) { + if (convert_src1_to_q8_1) { + char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset; + CUDA_CHECK(cudaMemcpyAsync(src1_ddq_i, src1_ddq_i_source, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, + cudaMemcpyDeviceToDevice, stream)); + } else { + float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device]; + src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; + CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_ncols*ne10*sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + } + } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) { + CUDA_CHECK(ggml_cuda_cpy_tensor_2d( + src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + } else { + GGML_ASSERT(false); + } + + if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) { + quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + } + + if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream)); + } + + // do the computation + op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, + row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + // copy dst to host or other device if necessary + if (!dst_on_device) { + void * dst_off_device; + cudaMemcpyKind kind; + if (dst->backend == GGML_BACKEND_CPU) { + dst_off_device = dst->data; + kind = cudaMemcpyDeviceToHost; + } else if (dst->backend == GGML_BACKEND_GPU) { + dst_off_device = dst_extra->data_device[g_main_device]; + kind = cudaMemcpyDeviceToDevice; + } else { + GGML_ASSERT(false); + } + if (split) { + // src0 = weight matrix is saved as a transposed matrix for better memory layout. + // dst is NOT transposed. + // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU. + // Instead they need to be copied to the correct slice in ne0 = dst row index. + // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results. 
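+                        // note: the 2D copy below writes src1_ncols rows of row_diff floats each from the
+                        // contiguous per-device buffer (source pitch row_diff*sizeof(float)) into dst rows
+                        // [src1_col_0, src1_col_0 + src1_ncols) with destination pitch ne0*sizeof(float),
+                        // at column offset row_low[id].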
+ float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0 + row_low[id]; + CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float), + row_diff*sizeof(float), src1_ncols, kind, stream)); + } else { + float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0; + CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), kind, stream)); + } + } + + // add event for the main device to wait on until other device is done + if (split && (id != g_main_device || is != 0)) { + CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream)); + } + } + } + } + + for (int64_t id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + continue; + } + CUDA_CHECK(ggml_cuda_set_device(id)); + + // free buffers again when done + if (src0_as[id] > 0) { + ggml_cuda_pool_free(src0_dd[id], src0_as[id]); + } + if (src1_asf[id] > 0) { + ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); + } + if (src1_asq[id] > 0) { + ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); + } + if (dst_as[id] > 0) { + ggml_cuda_pool_free(dst_dd[id], dst_as[id]); + } + } + + // main device waits for all other devices to be finished + if (split && g_device_count > 1) { + int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; + is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS; + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + for (int64_t id = 0; id < g_device_count; ++id) { + if (row_low[id] == row_high[id]) { + continue; + } + for (int64_t is = 0; is < is_max; ++is) { + CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0)); + } + } + } + + if (dst->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + CUDA_CHECK(cudaDeviceSynchronize()); + } +} + +static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_repeat); +} + +static void ggml_cuda_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_get_rows); +} + +static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add); +} + +static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc); +} + +static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); +} + +static void ggml_cuda_div(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_div); +} + +static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu); +} + +static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); +} + +static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick); +} + +static 
void ggml_cuda_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_tanh); +} + +static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu); +} + +static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu); +} + +static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr); +} + +static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); +} + +static void ggml_cuda_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_group_norm); +} + +static void ggml_cuda_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_concat); +} + +static void ggml_cuda_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_upscale); +} + +static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad); +} + +static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm); +} + +bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + if (!g_cublas_loaded) return false; + + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); +} + +static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne12 = src1->ne[2]; + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); +} + +static void 
ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(!ggml_is_permuted(src0)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + + const int64_t ne12 = src1->ne[2]; + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + const int64_t row_stride_x = nb01 / sizeof(half); + const int64_t channel_stride_x = nb02 / sizeof(half); + + ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); +} + +static __global__ void k_compute_batched_ptrs( + const half * src0_as_f16, const half * src1_as_f16, half * dst_f16, + const void ** ptrs_src, void ** ptrs_dst, + int ne12, int ne13, + int ne23, + int nb02, int nb03, + int nb12, int nb13, + int nb2, int nb3, + int r2, int r3) { + int i13 = blockIdx.x * blockDim.x + threadIdx.x; + int i12 = blockIdx.y * blockDim.y + threadIdx.y; + + if (i13 >= ne13 || i12 >= ne12) { + return; + } + + int i03 = i13 / r3; + int i02 = i12 / r2; + + ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; + ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2; + ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2; +} + +static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00); + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02); + const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); + const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); + + const int64_t ne1 = ggml_nelements(src1); + const int64_t ne = ggml_nelements(dst); + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream)); + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + half * src0_as_f16 = (half 
*) src0_ddq; + + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + // convert src1 to fp16 + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + + size_t src1_as = 0; + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); + to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + +#if 0 + // use cublasGemmEx + { + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + int i03 = i13 / r3; + int i02 = i12 / r2; + + CUBLAS_CHECK( + cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half), + (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float), + &beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + } + } +#else + if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) { + // there is no broadcast and src0, src1 are contiguous across dims 2, 3 + // use cublasGemmStridedBatchedEx + CUBLAS_CHECK( + cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA + (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB + &beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC + ne12*ne13, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + // use cublasGemmBatchedEx + const int ne23 = ne12*ne13; + + const void ** ptrs_src = nullptr; + void ** ptrs_dst = nullptr; + + size_t ptrs_src_s = 0; + size_t ptrs_dst_s = 0; + + ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s); + ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s); + + dim3 block_dims(ne13, ne12); + k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( + src0_as_f16, src1_as_f16, dst_f16, + ptrs_src, ptrs_dst, + ne12, ne13, + ne23, + nb02, nb03, + nb12, nb13, + dst->nb[2], dst->nb[3], + r2, r3); + CUDA_CHECK(cudaGetLastError()); + + CUBLAS_CHECK( + cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half), + (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float), + &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01, + ne23, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + if (ptrs_src_s != 0) { + ggml_cuda_pool_free(ptrs_src, ptrs_src_s); + } + if (ptrs_dst_s != 0) { + ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s); + } + } +#endif + + const to_fp32_cuda_t to_fp32_cuda = 
ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); + + ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free(dst_f16, dst_as); +} + +static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const bool all_on_device = + (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && + (src1->backend == GGML_BACKEND_GPU) && + ( dst->backend == GGML_BACKEND_GPU); + + const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + + int64_t min_compute_capability = INT_MAX; + for (int64_t id = 0; id < g_device_count; ++id) { + if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { + min_compute_capability = g_compute_capabilities[id]; + } + } + +#ifdef CUDA_USE_TENSOR_CORES + const bool use_tensor_cores = true; +#else + const bool use_tensor_cores = false; +#endif + + // debug helpers + //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); + //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); + //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); + //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); + //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); + //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); + + if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + // KQ single-batch + ggml_cuda_mul_mat_vec_p021(src0, src1, dst); + } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + // KQV single-batch + ggml_cuda_mul_mat_vec_nc(src0, src1, dst); + } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { + // KQ + KQV multi-batch + ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); + } else if (src0->type == GGML_TYPE_F32) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); + } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { + if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { +#ifdef GGML_CUDA_FORCE_DMMV + const bool use_mul_mat_vec_q = false; +#else + const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1; +#endif // GGML_CUDA_FORCE_DMMV + + if (use_mul_mat_vec_q) { + // NOTE: this kernel does not support ggml_nrows(src1) > 1 + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); + } else { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); + } + } else { + bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); + + // when tensor cores are available, use them for large batch size + // ref: https://github.com/ggerganov/llama.cpp/pull/3776 + if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) { + use_mul_mat_q = false; + } + + if 
(use_mul_mat_q) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true); + } else { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); + } + } + } else { + GGML_ASSERT(false); + } +} + +#if 0 +template +static __global__ void k_compute_batched_ptrs_id( + const void ** ptrs_src, void ** ptrs_dst, + int ne12, int ne13, + int ne23, + int nb02, int nb03, + int nb12, int nb13, + int nb2, int nb3, + int r2, int r3, + ggml_type src0_type, half * src0_as_f16, int64_t src0_ne, + const half * src1_f16, half * dst_f16, + const int32_t * ids, const int id, + Srcs... src0s) { + + int i = ids[id]; + + half * src0_f16; + const void * srcs_ar[] = { (const half *) src0s... }; + if (src0_type == GGML_TYPE_F16) { + src0_f16 = (half *) srcs_ar[i]; + } else { + src0_f16 = src0_as_f16; + if (threadIdx.x == 0 && threadIdx.y == 0) { + const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type); + to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget); + } + } + + int i13 = blockIdx.x * blockDim.x + threadIdx.x; + int i12 = blockIdx.y * blockDim.y + threadIdx.y; + + if (i13 >= ne13 || i12 >= ne12) { + return; + } + + int i03 = i13 / r3; + int i02 = i12 / r2; + + ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03; + ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2; + ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2; +} + +static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) { + const struct ggml_tensor * ids = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src00 = dst->src[2]; + + const int id = dst->op_params[0]; + + GGML_ASSERT(!ggml_is_transposed(src00)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00); + const int64_t ne01 = src00->ne[1]; + const int64_t ne02 = src00->ne[2]; + const int64_t ne03 = src00->ne[3]; + + //const int64_t nb01 = src00->nb[1]; + const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02); + const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + //const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); + const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); + + const int64_t ne1 = ggml_nelements(src1); + const int64_t ne = ggml_nelements(dst); + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream)); + + //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + //void * src0_ddq = src0_extra->data_device[g_main_device]; + //half * src0_as_f16 = (half *) src0_ddq; + + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + // convert src1 to fp16 + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + + size_t src1_as = 0; + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); + 
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + // use cublasGemmBatchedEx + const int ne23 = ne12*ne13; + + const void ** ptrs_src = nullptr; + void ** ptrs_dst = nullptr; + + size_t ptrs_src_s = 0; + size_t ptrs_dst_s = 0; + + ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s); + ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s); + + int64_t src0_ne = ggml_nelements(src00); + half * src0_as_f16 = nullptr; + size_t src0_as = 0; + if (src00->type != GGML_TYPE_F16) { + src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as); + } + + static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6"); + dim3 block_dims(ne13, ne12); + k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>( + ptrs_src, ptrs_dst, + ne12, ne13, + ne23, + ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half), + nb12, nb13, + dst->nb[2], dst->nb[3], + r2, r3, + src00->type, src0_as_f16, src0_ne, + src1_as_f16, dst_f16, + (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id, + dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr, + dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr, + dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr, + dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr + ); + CUDA_CHECK(cudaGetLastError()); + + CUBLAS_CHECK( + cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00, + (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10, + &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01, + ne23, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_as_f16, src0_as); + } + if (ptrs_src_s != 0) { + ggml_cuda_pool_free(ptrs_src, ptrs_src_s); + } + if (ptrs_dst_s != 0) { + ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s); + } + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); + + ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free(dst_f16, dst_as); +} +#endif + +static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#if 0 + ggml_cuda_mul_mat_id_cublas(dst); + // TODO: mmq/mmv support +#endif + + GGML_ASSERT(dst->backend == GGML_BACKEND_GPU); + + const struct ggml_tensor * ids = src0; + const int32_t id = ((int32_t *) dst->op_params)[0]; + const int32_t n_as = ((int32_t *) dst->op_params)[1]; + + std::vector ids_host(ggml_nbytes(ids)); + + if (ids->backend == GGML_BACKEND_GPU) { + const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device]; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); + CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); + } else { + memcpy(ids_host.data(), 
ids->data, ggml_nbytes(ids)); + } + + const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra; + const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra; + + ggml_tensor_extra_gpu src1_row_extra; + ggml_tensor_extra_gpu dst_row_extra; + + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + + src1_row.ne[1] = 1; + dst_row.ne[1] = 1; + + src1_row.nb[2] = src1_row.nb[1]; + dst_row.nb[2] = dst_row.nb[1]; + + src1_row.nb[3] = src1_row.nb[1]; + dst_row.nb[3] = dst_row.nb[1]; + + src1_row.extra = &src1_row_extra; + dst_row.extra = &dst_row_extra; + + + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + //int32_t row_id; + //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); + //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); + + const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(row_id >= 0 && row_id < n_as); + + const struct ggml_tensor * src0_row = dst->src[row_id + 2]; + + src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1]; + src1_row.data = (char *) src1->data + i01*src1->nb[1]; + + dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1]; + dst_row.data = (char *) dst->data + i01*dst->nb[1]; + + ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row); + } +} + +static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); +} + +static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp); +} + +static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne == ggml_nelements(src1)); + + GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + + GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); + GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + GGML_ASSERT(src0->ne[3] == 1); + + const int64_t nb00 = src0->nb[0]; + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + GGML_ASSERT(src1->ne[3] == 1); + + const int64_t nb10 = src1->nb[0]; + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + + char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; + char * src1_ddc = (char *) src1_extra->data_device[g_main_device]; + + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { + ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == 
GGML_TYPE_Q8_0) { + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else { + fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ASSERT(false); + } + + (void) dst; +} + +static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + // TODO: why do we pass dst as src1 here? + ggml_cuda_cpy(src0, dst, nullptr); + (void) src1; +} + +static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf); +} + +static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max); +} + +static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope); +} + +static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); +} + +static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col); +} + +static void ggml_cuda_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sum_rows); +} + +static void ggml_cuda_argsort(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_argsort); +} + +static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + (void) src0; + (void) src1; + (void) dst; +} + +static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]); +} + +void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { + const int64_t nrows = ggml_nrows(tensor); + + const int64_t ne0 = tensor->ne[0]; + + const size_t nb1 = tensor->nb[1]; + + ggml_backend_type backend = tensor->backend; + ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; + memset(extra, 0, sizeof(*extra)); + + for (int64_t id = 0; id < g_device_count; ++id) { + if (backend == GGML_BACKEND_GPU && id != g_main_device) { + continue; + } + + ggml_cuda_set_device(id); + + int64_t row_low, row_high; + if (backend == GGML_BACKEND_GPU) { + row_low = 0; + row_high = nrows; + } 
else if (backend == GGML_BACKEND_GPU_SPLIT) { + const int64_t rounding = get_row_rounding(tensor->type); + + row_low = id == 0 ? 0 : nrows*g_tensor_split[id]; + row_low -= row_low % rounding; + + if (id == g_device_count - 1) { + row_high = nrows; + } else { + row_high = nrows*g_tensor_split[id + 1]; + row_high -= row_high % rounding; + } + } else { + GGML_ASSERT(false); + } + if (row_low == row_high) { + continue; + } + + int64_t nrows_split = row_high - row_low; + + const size_t offset_split = row_low*nb1; + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + char * buf; + CUDA_CHECK(cudaMalloc(&buf, size)); + char * buf_host = (char*)data + offset_split; + + // set padding to 0 to avoid possible NaN values + if (size > original_size) { + CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size)); + } + + CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice)); + + extra->data_device[id] = buf; + + if (backend == GGML_BACKEND_GPU_SPLIT) { + for (int64_t is = 0; is < MAX_STREAMS; ++is) { + CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming)); + } + } + } + + tensor->extra = extra; +} + +void ggml_cuda_free_data(struct ggml_tensor * tensor) { + if (!tensor || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) { + return; + } + + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + + for (int64_t id = 0; id < g_device_count; ++id) { + if (extra->data_device[id] != nullptr) { + CUDA_CHECK(ggml_cuda_set_device(id)); + CUDA_CHECK(cudaFree(extra->data_device[id])); + } + + for (int64_t is = 0; is < MAX_STREAMS; ++is) { + if (extra->events[id][is] != nullptr) { + CUDA_CHECK(ggml_cuda_set_device(id)); + CUDA_CHECK(cudaEventDestroy(extra->events[id][is])); + } + } + } + + delete extra; +} + +static ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr; +static size_t g_temp_tensor_extra_index = 0; + +static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { + if (g_temp_tensor_extras == nullptr) { + g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES]; + } + + size_t alloc_index = g_temp_tensor_extra_index; + g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES; + ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index]; + memset(extra, 0, sizeof(*extra)); + + return extra; +} + +static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) { + if (scratch && g_scratch_size == 0) { + return; + } + + tensor->backend = GGML_BACKEND_GPU; + + // recursively assign CUDA buffers until a compute tensor is found + if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) { + const ggml_op src0_op = tensor->src[0]->op; + if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) { + ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc); + } + } + if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) { + ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc); + } + + if (scratch && no_alloc) { + return; + } + + ggml_tensor_extra_gpu * extra; + + const bool 
inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || + tensor->op == GGML_OP_VIEW || + force_inplace; + const size_t size = ggml_nbytes(tensor); + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; + char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; + size_t offset = 0; + if (tensor->op == GGML_OP_VIEW) { + memcpy(&offset, tensor->op_params, sizeof(size_t)); + } + extra = ggml_cuda_alloc_temp_tensor_extra(); + extra->data_device[g_main_device] = src0_ddc + offset; + } else if (tensor->op == GGML_OP_CPY) { + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra; + void * src1_ddv = src1_extra->data_device[g_main_device]; + extra = ggml_cuda_alloc_temp_tensor_extra(); + extra->data_device[g_main_device] = src1_ddv; + } else if (scratch) { + GGML_ASSERT(size <= g_scratch_size); + if (g_scratch_offset + size > g_scratch_size) { + g_scratch_offset = 0; + } + + char * data = (char *) g_scratch_buffer; + if (data == nullptr) { + CUDA_CHECK(cudaMalloc(&data, g_scratch_size)); + g_scratch_buffer = data; + } + extra = ggml_cuda_alloc_temp_tensor_extra(); + extra->data_device[g_main_device] = data + g_scratch_offset; + + g_scratch_offset += size; + + GGML_ASSERT(g_scratch_offset <= g_scratch_size); + } else { // allocate new buffers outside of scratch + void * data; + CUDA_CHECK(cudaMalloc(&data, size)); + CUDA_CHECK(cudaMemset(data, 0, size)); + extra = new ggml_tensor_extra_gpu; + memset(extra, 0, sizeof(*extra)); + extra->data_device[g_main_device] = data; + } + + tensor->extra = extra; +} + +void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) { + if (g_scratch_size == 0) { + return; + } + if (g_scratch_buffer == nullptr) { + ggml_cuda_set_device(g_main_device); + CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size)); + } + + ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra(); + + const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || + tensor->op == GGML_OP_VIEW; + + if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; + char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; + size_t view_offset = 0; + if (tensor->op == GGML_OP_VIEW) { + memcpy(&view_offset, tensor->op_params, sizeof(size_t)); + } + extra->data_device[g_main_device] = src0_ddc + view_offset; + } else { + extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset; + } + + tensor->extra = extra; +} + +void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice)); +} + +void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, true, false, false); +} + +void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, true, false, true); +} + +void 
ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, false, false, false); +} + +void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, false, true, false); +} + +void ggml_cuda_set_main_device(const int main_device) { + if (main_device >= g_device_count) { + fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n", + main_device, g_device_count, g_main_device); + return; + } + + if (g_main_device != main_device && g_device_count > 1) { + g_main_device = main_device; + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device)); + fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name); + } +} + +void ggml_cuda_set_scratch_size(const size_t scratch_size) { + // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously + // it still won't always work as expected, but it's better than nothing + if (scratch_size > g_scratch_size) { + ggml_cuda_free_scratch(); + } + g_scratch_size = std::max(g_scratch_size, scratch_size); +} + +void ggml_cuda_free_scratch() { + if (g_scratch_buffer == nullptr) { + return; + } + + CUDA_CHECK(cudaFree(g_scratch_buffer)); + g_scratch_buffer = nullptr; +} + +bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { + if (!g_cublas_loaded) return false; + + ggml_cuda_func_t func; + const bool any_on_device = tensor->backend == GGML_BACKEND_GPU + || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) + || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU); + + if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) { + return false; + } + + if (tensor->op == GGML_OP_MUL_MAT) { + if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) { +#ifndef NDEBUG + fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = " PRId64 ", src1->ne[3] = " PRId64 " - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]); +#endif + return false; + } + } + + switch (tensor->op) { + case GGML_OP_REPEAT: + func = ggml_cuda_repeat; + break; + case GGML_OP_GET_ROWS: + func = ggml_cuda_get_rows; + break; + case GGML_OP_DUP: + func = ggml_cuda_dup; + break; + case GGML_OP_ADD: + func = ggml_cuda_add; + break; + case GGML_OP_ACC: + func = ggml_cuda_acc; + break; + case GGML_OP_MUL: + func = ggml_cuda_mul; + break; + case GGML_OP_DIV: + func = ggml_cuda_div; + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_GELU: + func = ggml_cuda_gelu; + break; + case GGML_UNARY_OP_SILU: + func = ggml_cuda_silu; + break; + case GGML_UNARY_OP_GELU_QUICK: + func = ggml_cuda_gelu_quick; + break; + case GGML_UNARY_OP_TANH: + func = ggml_cuda_tanh; + break; + case GGML_UNARY_OP_RELU: + func = ggml_cuda_relu; + break; + default: + return false; + } + break; + case GGML_OP_NORM: + func = ggml_cuda_norm; + break; + case GGML_OP_GROUP_NORM: + func = ggml_cuda_group_norm; + break; + case GGML_OP_CONCAT: + func = ggml_cuda_concat; + break; + case GGML_OP_UPSCALE: + func = ggml_cuda_upscale; + break; + case GGML_OP_PAD: + func = ggml_cuda_pad; + break; + case GGML_OP_LEAKY_RELU: + func = ggml_cuda_leaky_relu; + break; + case GGML_OP_RMS_NORM: + func = ggml_cuda_rms_norm; + break; + case GGML_OP_MUL_MAT: + if (!any_on_device && 
!ggml_cuda_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) { + return false; + } + func = ggml_cuda_mul_mat; + break; + case GGML_OP_MUL_MAT_ID: + if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) { + return false; + } + func = ggml_cuda_mul_mat_id; + break; + case GGML_OP_SCALE: + func = ggml_cuda_scale; + break; + case GGML_OP_SQR: + func = ggml_cuda_sqr; + break; + case GGML_OP_CLAMP: + func = ggml_cuda_clamp; + break; + case GGML_OP_CPY: + func = ggml_cuda_cpy; + break; + case GGML_OP_CONT: + func = ggml_cuda_dup; + break; + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + func = ggml_cuda_nop; + break; + case GGML_OP_DIAG_MASK_INF: + func = ggml_cuda_diag_mask_inf; + break; + case GGML_OP_SOFT_MAX: + func = ggml_cuda_soft_max; + break; + case GGML_OP_ROPE: + func = ggml_cuda_rope; + break; + case GGML_OP_ALIBI: + func = ggml_cuda_alibi; + break; + case GGML_OP_IM2COL: + func = ggml_cuda_im2col; + break; + case GGML_OP_SUM_ROWS: + func = ggml_cuda_sum_rows; + break; + case GGML_OP_ARGSORT: + func = ggml_cuda_argsort; + break; + default: + return false; + } + + if (params->ith != 0) { + return true; + } + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return true; + } + func(tensor->src[0], tensor->src[1], tensor); + return true; +} + +int ggml_cuda_get_device_count() { + int device_count; + if (cudaGetDeviceCount(&device_count) != cudaSuccess) { + return 0; + } + return device_count; +} + +void ggml_cuda_get_device_description(int device, char * description, size_t description_size) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); + snprintf(description, description_size, "%s", prop.name); +} + +//////////////////////////////////////////////////////////////////////////////// + +// backend interface + +#define UNUSED GGML_UNUSED + +// cuda buffer + +struct ggml_backend_buffer_context_cuda { + int device; + void * dev_ptr = nullptr; + ggml_tensor_extra_gpu * temp_tensor_extras = nullptr; + size_t temp_tensor_extra_index = 0; + + ggml_backend_buffer_context_cuda(int device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {} + + ~ggml_backend_buffer_context_cuda() { + delete[] temp_tensor_extras; + } + + ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { + if (temp_tensor_extras == nullptr) { + temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES]; + } + + size_t alloc_index = temp_tensor_extra_index; + temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES; + ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index]; + memset(extra, 0, sizeof(*extra)); + + return extra; + } +}; + +static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + CUDA_CHECK(cudaFree(ctx->dev_ptr)); + delete ctx; +} + +static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + return ctx->dev_ptr; +} + +static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + + if (tensor->view_src != NULL && tensor->view_offs == 0) { + assert(tensor->view_src->buffer->buft == buffer->buft); // TODO + tensor->backend = 
tensor->view_src->backend; + tensor->extra = tensor->view_src->extra; + return; + } + + ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra(); + + extra->data_device[ctx->device] = tensor->data; + + tensor->backend = GGML_BACKEND_GPU; + tensor->extra = extra; + + if (ggml_is_quantized(tensor->type)) { + // initialize padding to 0 to avoid possible NaN values + int64_t row_low = 0; + int64_t row_high = ggml_nrows(tensor); + int64_t nrows_split = row_high - row_low; + + size_t original_size = ggml_nbytes_split(tensor, nrows_split); + size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); + + if (padded_size > original_size && tensor->view_src == nullptr) { + CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[ctx->device][0])); + } + } + + UNUSED(buffer); +} + +static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + + CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice)); + + UNUSED(buffer); +} + +static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + + CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost)); + + UNUSED(buffer); +} + +static struct ggml_backend_buffer_i cuda_backend_buffer_interface = { + /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer, + /* .get_base = */ ggml_backend_cuda_buffer_get_base, + /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor, + /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, + /* .cpy_tensor_from = */ NULL, + /* .cpy_tensor_to = */ NULL, +}; + +// cuda buffer type + +static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + int device = (int) (intptr_t) buft->context; + + ggml_cuda_set_device(device); + + size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0 + + void * dev_ptr; + CUDA_CHECK(cudaMalloc(&dev_ptr, size)); + + ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda(device, dev_ptr); + + return ggml_backend_buffer_init(buft, cuda_backend_buffer_interface, ctx, size); +} + +static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 128; + + UNUSED(buft); +} + +static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, ggml_tensor * tensor) { + int64_t row_low = 0; + int64_t row_high = ggml_nrows(tensor); + int64_t nrows_split = row_high - row_low; + + size_t size = ggml_nbytes_split(tensor, nrows_split); + + int64_t ne0 = tensor->ne[0]; + + if (ggml_is_quantized(tensor->type)) { + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + } + + return size; + + UNUSED(buft); +} + +static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t 
buft, ggml_backend_t backend) { + return ggml_backend_is_cuda(backend); + + UNUSED(buft); +} + +static ggml_backend_buffer_type_i cuda_backend_buffer_type_interface = { + /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment, + /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, + /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend, +}; + +ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda[GGML_CUDA_MAX_DEVICES]; + static bool ggml_backend_buffer_type_cuda_initialized = false; + if (!ggml_backend_buffer_type_cuda_initialized) { + for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) { + ggml_backend_buffer_type_cuda[i] = { + /* .iface = */ cuda_backend_buffer_type_interface, + /* .context = */ (ggml_backend_buffer_type_context_t) (intptr_t) i, + }; + } + ggml_backend_buffer_type_cuda_initialized = true; + } + + return &ggml_backend_buffer_type_cuda[device]; +} + +// host buffer type + +static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + CUDA_CHECK(cudaFreeHost(ctx->dev_ptr)); + delete ctx; +} + +static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * ptr; + CUDA_CHECK(cudaMallocHost(&ptr, size)); + + // FIXME: this is a hack to avoid having to implement a new buffer type + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; + + return buffer; + + UNUSED(buft); +} + +struct ggml_backend_buffer_type_i cuda_backend_host_buffer_type_interface = { + /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend, +}; + +ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda_host = { + /* .iface = */ cuda_backend_host_buffer_type_interface, + /* .context = */ nullptr, + }; + + return &ggml_backend_buffer_type_cuda_host; +} + +// backend + +struct ggml_backend_context_cuda { + int device; +}; + +static const char * ggml_backend_cuda_name(ggml_backend_t backend) { + return GGML_CUDA_NAME; + + UNUSED(backend); +} + +static void ggml_backend_cuda_free(ggml_backend_t backend) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + delete cuda_ctx; + delete backend; +} + +static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + return ggml_backend_cuda_buffer_type(cuda_ctx->device); +} + +static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) 
&& "tensor write out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + + CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0])); +} + +static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + + CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0])); +} + +static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[cuda_ctx->device][0])); + + UNUSED(backend); +} + +static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, ggml_cgraph * cgraph) { + GGML_ASSERT(!"not implemented"); + + return nullptr; + + UNUSED(backend); + UNUSED(cgraph); +} + +static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(!"not implemented"); + + UNUSED(backend); + UNUSED(plan); +} + +static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(!"not implemented"); + + UNUSED(backend); + UNUSED(plan); +} + +static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + ggml_cuda_set_main_device(cuda_ctx->device); + + ggml_compute_params params = {}; + params.type = GGML_TASK_COMPUTE; + params.ith = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) + continue; + + assert(node->backend == GGML_BACKEND_GPU); + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + assert(node->extra != nullptr); + + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->backend == GGML_BACKEND_GPU); + assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + assert(node->src[j]->extra != nullptr); + } + } + + bool ok = ggml_cuda_compute_forward(¶ms, node); + if (!ok) { + fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + +#if 0 + if (node->type == GGML_TYPE_F32) { + cudaDeviceSynchronize(); + std::vector tmp(ggml_nelements(node), 0.0f); + cudaMemcpy(tmp.data(), node->data, ggml_nelements(node)*sizeof(float), cudaMemcpyDeviceToHost); + printf("\n%s (%s) (%s %s) (%s %s): ", node->name, ggml_op_name(node->op), + ggml_type_name(node->src[0]->type), + node->src[1] ? ggml_type_name(node->src[1]->type) : "none", + node->src[0]->name, + node->src[1] ? 
node->src[1]->name : "none"); + double sum = 0.0; + double sq_sum = 0.0; + for (int i = 0; i < ggml_nelements(node); i++) { + printf("%f ", tmp[i]); + sum += tmp[i]; + sq_sum += tmp[i]*tmp[i]; + } + printf("\n"); + printf("sum: %f, ", sum); + printf("sq_sum: %f\n", sq_sum); + } +#endif + } + + UNUSED(backend); +} + +static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_TANH: + return true; + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + struct ggml_tensor * a; + struct ggml_tensor * b; + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != b->ne[3]) { + return false; + } + return true; + } break; + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return true; + default: + return false; + } + } break; + case GGML_OP_CPY: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1]->type; + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) { + return true; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return true; + } + return false; + } break; + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NORM: + case GGML_OP_REPEAT: + case GGML_OP_DUP: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_CLAMP: + case GGML_OP_CONT: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_ALIBI: + case GGML_OP_IM2COL: + case GGML_OP_SUM_ROWS: + case GGML_OP_ARGSORT: + case GGML_OP_ACC: + case GGML_OP_CONCAT: + case GGML_OP_GROUP_NORM: + case GGML_OP_UPSCALE: + case GGML_OP_PAD: + case GGML_OP_LEAKY_RELU: + return true; + default: + return false; + } + + UNUSED(backend); +} + +static ggml_backend_i cuda_backend_i = { + /* .get_name = */ ggml_backend_cuda_name, + /* .free = */ ggml_backend_cuda_free, + /* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type, + /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .cpy_tensor_from_async = */ NULL, + /* .cpy_tensor_to_async = */ NULL, + /* .synchronize = */ ggml_backend_cuda_synchronize, + /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free, + /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cuda_graph_compute, + /* .supports_op = */ ggml_backend_cuda_supports_op, +}; + +ggml_backend_t ggml_backend_cuda_init(int device) { + ggml_init_cublas(); // TODO: remove from ggml.c + + if (device < 0 || device >= 
ggml_cuda_get_device_count()) { + fprintf(stderr, "%s: error: invalid device %d\n", __func__, device); + return nullptr; + } + + // not strictly necessary, but it may reduce the overhead of the first graph_compute + ggml_cuda_set_main_device(device); + + ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda { + /* .device = */ device + }; + + ggml_backend_t cuda_backend = new ggml_backend { + /* .interface = */ cuda_backend_i, + /* .context = */ ctx + }; + + return cuda_backend; +} + +bool ggml_backend_is_cuda(ggml_backend_t backend) { + return backend->iface.get_name == ggml_backend_cuda_name; +} + +static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) { + ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data); + return cuda_backend; + + UNUSED(params); +} + +extern "C" int ggml_backend_cuda_reg_devices(); + +int ggml_backend_cuda_reg_devices() { + int device_count = ggml_cuda_get_device_count(); + //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization + for (int i = 0; i < device_count; i++) { + char name[128]; + snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i); + ggml_backend_register(name, ggml_backend_reg_cuda_init, ggml_backend_cuda_buffer_type(i), (void *) (intptr_t) i); + } + return device_count; +} diff --git a/applications/src/llama/ggml-cuda.h b/applications/src/llama/ggml-cuda.h new file mode 100644 index 000000000..cdb0c0c41 --- /dev/null +++ b/applications/src/llama/ggml-cuda.h @@ -0,0 +1,64 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_HIPBLAS +#define GGML_CUDA_NAME "ROCm" +#define GGML_CUBLAS_NAME "hipBLAS" +#else +#define GGML_CUDA_NAME "CUDA" +#define GGML_CUBLAS_NAME "cuBLAS" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_CUDA_MAX_DEVICES 16 + +// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`. +GGML_API void ggml_init_cublas(void); + +// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`. 
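+//
+// Example (sketch, illustrating the two comments above rather than code taken from
+// this header): initialize once at startup, then gate any GPU offloading on the
+// result of ggml_cublas_loaded().
+//
+//   ggml_init_cublas();
+//   if (!ggml_cublas_loaded()) {
+//       /* no usable CUDA device / cuBLAS - stay on the CPU path */
+//   }
+//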
+GGML_API bool   ggml_cublas_loaded(void);
+
+GGML_API void * ggml_cuda_host_malloc(size_t size);
+GGML_API void   ggml_cuda_host_free(void * ptr);
+
+GGML_API bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+GGML_API void   ggml_cuda_set_tensor_split(const float * tensor_split);
+GGML_API void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_free_data(struct ggml_tensor * tensor);
+
+GGML_API void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void   ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+GGML_API void   ggml_cuda_copy_to_device(struct ggml_tensor * tensor);
+
+GGML_API void   ggml_cuda_set_main_device(int main_device);
+GGML_API void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
+GGML_API void   ggml_cuda_set_scratch_size(size_t scratch_size);
+GGML_API void   ggml_cuda_free_scratch(void);
+GGML_API bool   ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+
+GGML_API int    ggml_cuda_get_device_count(void);
+GGML_API void   ggml_cuda_get_device_description(int device, char * description, size_t description_size);
+
+// backend API
+GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
+
+GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
+GGML_API int  ggml_backend_cuda_get_device(ggml_backend_t backend);
+
+GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
+// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
+GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/applications/src/llama/ggml-impl.h b/applications/src/llama/ggml-impl.h
new file mode 100644
index 000000000..1f5610a86
--- /dev/null
+++ b/applications/src/llama/ggml-impl.h
@@ -0,0 +1,243 @@
+#pragma once
+
+#include "ggml.h"
+
+// GGML internal header
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdbool.h>
+#include <string.h> // memcpy
+#include <math.h>   // fabsf
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// static_assert should be a #define, but if it's not,
+// fall back to the _Static_assert C11 keyword.
+// if C99 - static_assert is noop
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+
+// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __FMA__
+#define __FMA__
+#endif
+#ifndef __F16C__
+#define __F16C__
+#endif
+#ifndef __SSE3__
+#define __SSE3__
+#endif
+#endif
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#if defined(__ARM_NEON) && !defined(_MSC_VER)
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
+#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
+
+#define GGML_FP16_TO_FP32(x) ((float) (x))
+#define GGML_FP32_TO_FP16(x) (x)
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#ifdef __F16C__
+
+#ifdef _MSC_VER
+#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
+#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
+#else
+#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+#endif
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+/* the inline asm below is about 12% faster than the lookup method */
+#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    register float f;
+    register double d;
+    __asm__(
+        "mtfprd %0,%2\n"
+        "xscvhpdp %0,%0\n"
+        "frsp %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=f"(f):
+        /* in */   "r"(h));
+    return f;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+    register double d;
+    register ggml_fp16_t r;
+    __asm__( /* xscvdphp can work on double or single precision */
+        "xscvdphp %0,%2\n"
+        "mffprd %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=r"(r):
+        /* in */   "f"(f));
+    return r;
+}
+
+#else
+
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32;
+    fp32.as_bits = w;
+    return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
+}
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
+
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+    const float exp_scale = 
0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // __ARM_NEON + +// precomputed f32 table for f16 (256 KB) +// defined in ggml.c, initialized in ggml_init() +extern float ggml_table_f32_f16[1 << 16]; + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, +// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. +// This is also true for POWER9. 
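+// For the remaining targets, the 1 << 16 entry table above (65536 floats * 4 bytes = 256 KB,
+// one entry per possible fp16 bit pattern) backs the fallback below, so each conversion is
+// essentially a single indexed load in ggml_lookup_fp16_to_fp32().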
+#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
+
+inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
+    uint16_t s;
+    memcpy(&s, &f, sizeof(uint16_t));
+    return ggml_table_f32_f16[s];
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+#endif
+
+#define GGML_HASHTABLE_FULL ((size_t)-1)
+#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
+
+bool   ggml_hash_contains      (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
+size_t ggml_hash_find          (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
+size_t ggml_hash_insert        (      struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// return index, asserts if table is full
+size_t ggml_hash_find_or_insert(      struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/applications/src/llama/ggml-quants.c b/applications/src/llama/ggml-quants.c
new file mode 100644
index 000000000..0e8163a16
--- /dev/null
+++ b/applications/src/llama/ggml-quants.c
@@ -0,0 +1,7382 @@
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+
+#ifdef __ARM_NEON
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if !defined(__riscv)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? 
(a) : (b)) + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(x, x); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + const __m128i ones = _mm_set1_epi16(1); + return _mm_madd_epi16(ones, dot); +} + +#if __AVX__ || __AVX2__ || __AVX512F__ +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#if defined(__AVX2__) || defined(__AVX512F__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 
15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8( 0xF ); + return _mm256_and_si256(lowMask, bytes); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { +#if __AVXVNNI__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} +#elif defined(__AVX__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); + __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); + __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); + const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytesl = _mm_or_si128(bytesl, bit_mask); + bytesh = _mm_or_si128(bytesh, bit_mask); + bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); + bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); + return MM256_SET_M128I(bytesh, bytesl); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 
15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + // Load 16 bytes from memory + __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); + __m128i tmph = _mm_srli_epi16(tmpl, 4); + const __m128i lowMask = _mm_set1_epi8(0xF); + tmpl = _mm_and_si128(lowMask, tmpl); + tmph = _mm_and_si128(lowMask, tmph); + return MM256_SET_M128I(tmph, tmpl); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { + const __m128i ones = _mm_set1_epi16(1); + const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); + const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m128i xl = _mm256_castsi256_si128(x); + const __m128i xh = _mm256_extractf128_si256(x, 1); + const __m128i yl = _mm256_castsi256_si128(y); + const __m128i yh = _mm256_extractf128_si256(y, 1); + // Get absolute values of x vectors + const __m128i axl = _mm_sign_epi8(xl, xl); + const __m128i axh = _mm_sign_epi8(xh, xh); + // Sign the values of the y vectors + const __m128i syl = _mm_sign_epi8(yl, xl); + const __m128i syh = _mm_sign_epi8(yh, xh); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m128i lowByte = _mm_set1_epi16( 0xFF ); + __m128i high = _mm_andnot_si128( lowByte, bytes1 ); + __m128i low = _mm_and_si128( lowByte, bytes1 ); + high = _mm_srli_epi16( high, 4 ); + bytes1 = _mm_or_si128( low, high ); + high = _mm_andnot_si128( lowByte, bytes2 ); + low = _mm_and_si128( lowByte, bytes2 ); + high = _mm_srli_epi16( high, 4 ); + bytes2 = _mm_or_si128( low, high ); + + return _mm_packus_epi16( bytes1, bytes2); +} +#endif +#elif defined(__SSSE3__) +// horizontally add 4x4 floats +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =_mm_hadd_ps(a, b); + __m128 res_1 =_mm_hadd_ps(c, d); + __m128 res =_mm_hadd_ps(res_0, res_1); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif // __AVX__ || __AVX2__ || __AVX512F__ +#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) + +#if defined(__ARM_NEON) +#if !defined(__aarch64__) + +// 64-bit compatibility + +// vaddvq_s16 +// vpaddq_s16 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// vcvtnq_s32_f32 + +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 
2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline static float vmaxvq_f32(float32x4_t v) { + return + MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct ggml_int16x8x2_t { + int16x8_t val[2]; +} ggml_int16x8x2_t; + +inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { + ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct ggml_uint8x16x2_t { + uint8x16_t val[2]; +} ggml_uint8x16x2_t; + +inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { + ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct ggml_uint8x16x4_t { + uint8x16_t val[4]; +} ggml_uint8x16x4_t; + +inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { + ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct ggml_int8x16x2_t { + int8x16_t val[2]; +} ggml_int8x16x2_t; + +inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { + ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct ggml_int8x16x4_t { + int8x16_t val[4]; +} ggml_int8x16x4_t; + +inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { + ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +#else + +#define ggml_int16x8x2_t int16x8x2_t +#define ggml_uint8x16x2_t uint8x16x2_t +#define ggml_uint8x16x4_t uint8x16x4_t +#define ggml_int8x16x2_t int8x16x2_t +#define ggml_int8x16x4_t int8x16x4_t + +#define ggml_vld1q_s16_x2 vld1q_s16_x2 +#define ggml_vld1q_u8_x2 vld1q_u8_x2 +#define ggml_vld1q_u8_x4 vld1q_u8_x4 +#define ggml_vld1q_s8_x2 vld1q_s8_x2 +#define ggml_vld1q_s8_x4 vld1q_s8_x4 + +#endif +#endif + +#if defined(__ARM_NEON) || defined(__wasm_simd128__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), 
B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +// reference implementation for deterministic creation of model files +void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { + quantize_row_q4_0_reference(x, y, k); +} + +void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { + const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { + quantize_row_q4_1_reference(x, y, k); +} + +void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(qh)); + } +} + +void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { + quantize_row_q5_0_reference(x, y, k); +} + +void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { + const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 5) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { + quantize_row_q5_1_reference(x, y, k); +} + +// reference implementation for deterministic creation of model files +void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[i*QK8_0 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[i*QK8_0 + j]*id; + + y[i].qs[j] = roundf(x0); + } + } +} + +void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( maxScalar != 0.0f ) ? 
127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_0); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_reference(x, y, k); +#endif +} + +// reference implementation for deterministic creation of model files +void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { + assert(QK8_1 == 32); + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_1; j++) { + const float v = x[i*QK8_1 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + int sum = 0; + + for (int j = 0; j < QK8_1/2; ++j) { + const float v0 = x[i*QK8_1 + j]*id; + const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id; + + y[i].qs[ j] = roundf(v0); + y[i].qs[QK8_1/2 + j] = roundf(v1); + + sum += y[i].qs[ j]; + sum += y[i].qs[QK8_1/2 + j]; + } + + y[i].s = sum*d; + } +} + +void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + + accv = vaddq_s32(accv, vi); + } + + y[i].s = d * vaddvq_s32(accv); + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = d; + + v128_t accv = wasm_i32x4_splat(0); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + + accv = wasm_i32x4_add(accv, vi); + } + + y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) + + wasm_i32x4_extract_lane(accv, 1) + + wasm_i32x4_extract_lane(accv, 2) + + wasm_i32x4_extract_lane(accv, 3)); + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Compute the sum of the quants and set y[i].s + y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = 
_mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); + y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1)); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_1); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + + // compute sum for y[i].s + vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); + + // set y[i].s + int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); + y[i].s = sum*d; + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_reference(x, y, k); +#endif +} + +void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { + static const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F); + const int x1 = (x[i].qs[j] >> 4); + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +void 
dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { + static const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int x0 = (x[i].qs[j] & 0x0F) | xh_0; + const int x1 = (x[i].qs[j] >> 4) | xh_1; + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) { + static const int qk = QK8_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk; ++j) { + y[i*qk + j] = x[i].qs[j]*d; + } + } +} + +// +// 2-6 bit quantization in super-blocks +// + +// +// ===================== Helper functions +// +static inline int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (amax < 1e-30f) { // all zero + for (int i = 0; i < n; ++i) { + L[i] = 0; + } + return 0.f; + } + float iscale = -nmax / max; + if (rmse_type == 0) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + return 1/iscale; + } + bool return_early = false; + if (rmse_type < 0) { + rmse_type = -rmse_type; + return_early = true; + } + int weight_type = rmse_type%2; + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + float w = weight_type == 1 ? x[i] * x[i] : 1; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + float scale = sumlx/suml2; + if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; + float best = scale * sumlx; + for (int is = -9; is <= 9; ++is) { + if (is == 0) { + continue; + } + iscale = -(nmax + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + float w = weight_type == 1 ? 
x[i] * x[i] : 1; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + } + } + return scale; +} + +static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (!amax) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = -nmax / max; + if (do_rmse) { + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l; + float w = x[i]*x[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = x[i]*x[i]; + float slx = sumlx - w*x[i]*L[i]; + if (slx > 0) { + float sl2 = suml2 - w*L[i]*L[i]; + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MAX(-nmax, MIN(nmax-1, new_l)); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + for (int i = 0; i < n; ++i) { + L[i] += nmax; + } + return sumlx / suml2; + } + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + } + return 1/iscale; +} + +static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, + int ntry, float alpha) { + float min = x[0]; + float max = x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + } + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = 0; + return 0.f; + } + if (min > 0) min = 0; + float iscale = nmax/(max - min); + float scale = 1/iscale; + for (int itry = 0; itry < ntry; ++itry) { + float sumlx = 0; int suml2 = 0; + bool did_change = false; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + if (l != L[i]) { + L[i] = l; + did_change = true; + } + sumlx += (x[i] - min)*l; + suml2 += l*l; + } + scale = sumlx/suml2; + float sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] - scale*L[i]; + } + min = alpha*min + (1 - alpha)*sum/n; + if (min > 0) min = 0; + iscale = 1/scale; + if (!did_change) break; + } + *the_min = -min; + return scale; +} + +static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights, + uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights[0]; + float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else + for (int i = 1; i < n; ++i) { +#endif + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) min = 0; + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + 
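+    // Error of the initial (scale, min) guess taken from the raw [min, max]
+    // range: each element contributes its weighted reconstruction error
+    //   err_i = scale*L[i] + min - x[i]
+    // either as |err_i| (use_mad) or as err_i^2. This baseline is what the
+    // grid search over candidate scales below has to beat before it replaces
+    // (scale, min) and the quant indices L.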
for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } + } + *the_min = -min; + return scale; +} + +#if QK_K == 256 +static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { + if (j < 4) { + *d = q[j] & 63; *m = q[j + 4] & 63; + } else { + *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} +#endif + +//========================- 2-bit (de)-quantization + +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[16]; + float weights[16]; + float mins[QK_K/16]; + float scales[QK_K/16]; + + const float q4scale = 15.f; + + for (int i = 0; i < nb; i++) { + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]); + scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + if (max_scale > 0) { + float iscale = q4scale/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + y[i].scales[j] = l; + } + y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale); + } else { + for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0; + y[i].d = GGML_FP32_TO_FP16(0.f); + } + if (max_min > 0) { + float iscale = q4scale/max_min; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*mins[j]); + y[i].scales[j] |= (l << 4); + } + y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale); + } else { + y[i].dmin = GGML_FP32_TO_FP16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF); + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4); + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int((x[16*j + ii] + dm)/d); + l = MAX(0, MIN(3, l)); + L[16*j + ii] = l; + } + } + +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } +#else + 
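+        // QK_K == 64: the super-block holds only 64 2-bit quants, so they pack
+        // into 16 bytes with four values per byte:
+        //   qs[l] = L[l] | L[l+16] << 2 | L[l+32] << 4 | L[l+48] << 6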
for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif + + x += QK_K; + + } +} + +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + const float min = GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * q = x[i].qs; + +#if QK_K == 256 + int is = 0; + float dl, ml; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + uint8_t sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; + + sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; + + shift += 2; + } + q += 32; + } +#else + float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); + float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); + float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); + float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); + for (int l = 0; l < 16; ++l) { + y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1; + y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2; + y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3; + y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4; + } + y += QK_K; +#endif + } +} + +void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { + quantize_row_q2_K_reference(x, vy, k); +} + +size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K; + quantize_row_q2_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q2_K)); +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); + float scale = fabsf(scales[j]); + if (scale > amax) { + amax = scale; max_scale = scales[j]; + } + } + +#if QK_K == 256 + memset(y[i].scales, 0, 12); + if (max_scale) { + float iscale = -32.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int8_t l = nearest_int(iscale*scales[j]); + l = MAX(-32, MIN(31, l)) + 32; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = GGML_FP32_TO_FP16(1/iscale); + } else { + y[i].d = GGML_FP32_TO_FP16(0.f); + } + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? 
y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; + float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } +#else + if (max_scale) { + float iscale = -8.f/max_scale; + for (int j = 0; j < QK_K/16; j+=2) { + int l1 = nearest_int(iscale*scales[j]); + l1 = 8 + MAX(-8, MIN(7, l1)); + int l2 = nearest_int(iscale*scales[j+1]); + l2 = 8 + MAX(-8, MIN(7, l2)); + y[i].scales[j/2] = l1 | (l2 << 4); + } + y[i].d = GGML_FP32_TO_FP16(1/iscale); + } else { + for (int j = 0; j < QK_K/16; j+=2) { + y[i].scales[j/2] = 0; + } + y[i].d = GGML_FP32_TO_FP16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4; + float d = GGML_FP16_TO_FP32(y[i].d) * (s - 8); + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } +#endif + + memset(y[i].hmask, 0, QK_K/8); + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. + int m = 0; + uint8_t hm = 1; + for (int j = 0; j < QK_K; ++j) { + if (L[j] > 3) { + y[i].hmask[m] |= hm; + L[j] -= 4; + } + if (++m == QK_K/8) { + m = 0; hm <<= 1; + } + } +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif + + x += QK_K; + } +} + +#if QK_K == 256 +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + uint32_t aux[4]; + const int8_t * scales = (const int8_t*)aux; + + for (int i = 0; i < nb; i++) { + + const float d_all = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + uint8_t m = 1; + + memcpy(aux, x[i].scales, 12); + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 
0 : 4)); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + + } +} +#else +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + assert(QK_K == 64); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d_all = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + + const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); + const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); + const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); + const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + + for (int l=0; l<8; ++l) { + uint8_t h = hm[l]; + y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); + y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); + y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); + y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); + y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); + y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); + y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); + y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4)); + } + y += QK_K; + } +} +#endif + +void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { + quantize_row_q3_K_reference(x, vy, k); +} + +size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K; + quantize_row_q3_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q3_K)); +} + +// ====================== 4-bit (de)-quantization + +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[32]; + float weights[32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; + float av_x = sqrtf(sum_x2/32); + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + +#if QK_K == 256 + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 
63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = GGML_FP32_TO_FP16(max_scale/63.f); + y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } +#else + const float s_factor = 15.f; + float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f; + float inv_min = max_min > 0 ? s_factor/max_min : 0.f; + int d1 = nearest_int(inv_scale*scales[0]); + int m1 = nearest_int(inv_min*mins[0]); + int d2 = nearest_int(inv_scale*scales[1]); + int m2 = nearest_int(inv_min*mins[1]); + y[i].scales[0] = d1 | (m1 << 4); + y[i].scales[1] = d2 | (m2 << 4); + y[i].d[0] = GGML_FP32_TO_FP16(max_scale/s_factor); + y[i].d[1] = GGML_FP32_TO_FP16(max_min/s_factor); + + float sumlx = 0; + int suml2 = 0; + for (int j = 0; j < QK_K/32; ++j) { + const uint8_t sd = y[i].scales[j] & 0xF; + const uint8_t sm = y[i].scales[j] >> 4; + const float d = GGML_FP16_TO_FP32(y[i].d[0]) * sd; + if (!d) continue; + const float m = GGML_FP16_TO_FP32(y[i].d[1]) * sm; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + m)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + sumlx += (x[32*j + ii] + m)*l*sd; + suml2 += l*l*sd*sd; + } + } + if (suml2) { + y[i].d[0] = GGML_FP32_TO_FP16(sumlx/suml2); + } +#endif + uint8_t * q = y[i].qs; + for (int j = 0; j < QK_K; j += 64) { + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); + q += 32; + } + + x += QK_K; + + } +} + +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * q = x[i].qs; + +#if QK_K == 256 + + const float d = GGML_FP16_TO_FP32(x[i].d); + const float min = GGML_FP16_TO_FP32(x[i].dmin); + + int is = 0; + uint8_t sc, m; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; + q += 32; is += 2; + } +#else + const float dall = GGML_FP16_TO_FP32(x[i].d[0]); + const float mall = GGML_FP16_TO_FP32(x[i].d[1]); + const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4); + const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4); + for (int l = 0; l < 32; ++l) { + y[l+ 0] = d1 * (q[l] & 0xF) - m1; + y[l+32] = d2 * (q[l] >> 4) - m2; + } + y += QK_K; +#endif + + } +} + +void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q4_K * restrict y = vy; + quantize_row_q4_K_reference(x, y, k); +} + +size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) 
{ + assert(k % QK_K == 0); + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K; + quantize_row_q4_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q4_K)); +} + +// ====================== 5-bit (de)-quantization + +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + +#if QK_K == 256 + uint8_t L[QK_K]; + float mins[QK_K/32]; + float scales[QK_K/32]; + float weights[32]; + uint8_t Laux[32]; +#else + int8_t L[QK_K]; + float scales[QK_K/16]; +#endif + + for (int i = 0; i < nb; i++) { + +#if QK_K == 256 + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; + float av_x = sqrtf(sum_x2/32); + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = GGML_FP32_TO_FP16(max_scale/63.f); + y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + uint8_t m1 = 1, m2 = 2; + for (int n = 0; n < QK_K; n += 64) { + for (int j = 0; j < 32; ++j) { + int l1 = L[n + j]; + if (l1 > 15) { + l1 -= 16; qh[j] |= m1; + } + int l2 = L[n + j + 32]; + if (l2 > 15) { + l2 -= 16; qh[j] |= m2; + } + ql[j] = l1 | (l2 << 4); + } + m1 <<= 2; m2 <<= 2; + ql += 32; + } +#else + float max_scale = 0, amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1); + float abs_scale = fabsf(scales[j]); + if (abs_scale > amax) { + amax = abs_scale; + max_scale = scales[j]; + } + } + + float iscale = -128.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + y[i].scales[j] = MAX(-128, MIN(127, l)); + } + y[i].d = GGML_FP32_TO_FP16(1/iscale); + + for (int j = 0; j < QK_K/16; ++j) { + const float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) continue; + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-16, MIN(15, l)); + L[16*j + ii] = l + 16; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + for (int j 
= 0; j < 32; ++j) { + int jm = j%8; + int is = j/8; + int l1 = L[j]; + if (l1 > 15) { + l1 -= 16; qh[jm] |= (1 << is); + } + int l2 = L[j + 32]; + if (l2 > 15) { + l2 -= 16; qh[jm] |= (1 << (4 + is)); + } + ql[j] = l1 | (l2 << 4); + } +#endif + + x += QK_K; + + } +} + +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * ql = x[i].qs; + const uint8_t * qh = x[i].qh; + +#if QK_K == 256 + + const float d = GGML_FP16_TO_FP32(x[i].d); + const float min = GGML_FP16_TO_FP32(x[i].dmin); + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } +#else + float d = GGML_FP16_TO_FP32(x[i].d); + const int8_t * restrict s = x[i].scales; + for (int l = 0; l < 8; ++l) { + y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); + y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); + y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); + y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); + y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); + y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); + y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); + y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 
0 : 16)); + } + y += QK_K; +#endif + } +} + +void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q5_K * restrict y = vy; + quantize_row_q5_K_reference(x, y, k); +} + +size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + assert(k % QK_K == 0); + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; + quantize_row_q5_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q5_K)); +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K/16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + + const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1); + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; + } + + } + + if (!max_abs_scale) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = GGML_FP32_TO_FP16(0.f); + x += QK_K; + continue; + } + + float iscale = -128.f/max_scale; + y[i].d = GGML_FP32_TO_FP16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + } + + for (int j = 0; j < QK_K/16; ++j) { + float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; + } + } + + uint8_t * restrict ql = y[i].ql; + uint8_t * restrict qh = y[i].qh; +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[j + l + 0] & 0xF; + const uint8_t q2 = L[j + l + 32] & 0xF; + const uint8_t q3 = L[j + l + 64] & 0xF; + const uint8_t q4 = L[j + l + 96] & 0xF; + ql[l+ 0] = q1 | (q3 << 4); + ql[l+32] = q2 | (q4 << 4); + qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); + } + ql += 64; + qh += 32; + } +#else + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[l + 0] & 0xF; + const uint8_t q2 = L[l + 32] & 0xF; + ql[l] = q1 | (q2 << 4); + } + for (int l = 0; l < 16; ++l) { + qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6); + } +#endif + + x += QK_K; + + } +} + +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict ql = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict sc = x[i].scales; + +#if QK_K == 256 + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + 
} + y += 128; + ql += 64; + qh += 32; + sc += 8; + } +#else + for (int l = 0; l < 16; ++l) { + const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l+ 0] = d * sc[0] * q1; + y[l+16] = d * sc[1] * q2; + y[l+32] = d * sc[2] * q3; + y[l+48] = d * sc[3] * q4; + } + y += 64; +#endif + + } +} + +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q6_K * restrict y = vy; + quantize_row_q6_K_reference(x, y, k); +} + +size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK_K == 0); + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K; + quantize_row_q6_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q6_K)); +} + +//===================================== Q8_K ============================================== + +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + const float iscale = -128.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += QK_K; + } +} + +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK_K; ++j) { + *y++ = x[i].d * x[i].qs[j]; + } + } +} + +void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) { + quantize_row_q8_K_reference(x, y, k); +} + +//===================================== Dot ptoducts ================================= + +// +// Helper functions +// +#if __AVX__ || __AVX2__ || __AVX512F__ + +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 
6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} +#endif + +void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 
(v0_1hs), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + bx = _mm256_sub_epi8( bx, off ); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx = _mm_and_si128(lowMask, tmp); + __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx, by); + + bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); + by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx, by); + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); + + // Apply the scale, and accumulate + acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + // First round without accumulation + { + _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const 
__m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + acc_0 = _mm_mul_ps( d_0_1, p0 ); + acc_1 = _mm_mul_ps( d_0_1, p1 ); + acc_2 = _mm_mul_ps( d_2_3, p2 ); + acc_3 = _mm_mul_ps( d_2_3, p3 ); + } + + assert(nb % 2 == 0); // TODO: handle odd nb + + // Main loop + for (int i = 2; i < nb; i+=2) { + _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + + // Acummulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + // load elements + vuint8mf2_t tx = 
__riscv_vle8_v_u8mf2(x[i].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + // subtract offset + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F) - 8; + const int v1 = (x[i].qs[j] >> 4) - 8; + + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } + + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q4_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + + // TODO: add WASM SIMD +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs = 0; + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q4_1 * restrict x0 = &x[i + 0]; + const block_q4_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i + 0]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); + + 
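+        // Without SDOT the dot product is built by hand: widening 8x8 -> 16-bit
+        // multiplies on the low/high halves of each vector, then pairwise
+        // additions up to 32 bits (vpaddlq_s16/vaddq_s32), before the result is
+        // scaled by d and accumulated. This yields the same per-block sum that
+        // vdotq_s32 produces directly in the __ARM_FEATURE_DOTPROD path above.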
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + // Main loop + for (int i = 0; i < nb; ++i) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d1 = y[i].d; + + summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + + const __m256 xy = mul_sum_us8_pairs_float(bx, by); + + // Accumulate d0*d1*x*y +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d0d1, xy, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F); + const int v1 = (x[i].qs[j] >> 4); + + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(qk == QK5_0); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * 
restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q5_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + // extract the 5th bit via lookup table ((!b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + 
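// accumulate the second block, scaled by its own combined fp16 d values +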
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); + const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( + wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + } + + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + bx = _mm256_or_si256(bx, bxhi); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8((char)0xF0); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_andnot_si128(bxhil, mask); + bxhih 
= _mm_andnot_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx); + __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx = MM256_SET_M128I(bxh, bxl); + + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); + } + + *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + uint32_t qh; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + // These temporary registers are for masking and shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); + + vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); + vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); + + // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + + // ((qh & (1u << (j + 16))) >> (j + 12)); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); + vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); + + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); + + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } + + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(qk == 
QK5_1); + + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q5_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s; + summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s; + + // extract the 5th bit via lookup table ((b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit + const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = 
vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + float summs = 0.0f; + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q8_1 * restrict y0 = &y[i]; + + summs += GGML_FP16_TO_FP32(x0->m) * y0->s; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit + const v128_t v0lf = wasm_v128_or(v0l, qhl); + const v128_t v0hf = wasm_v128_or(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d))); + } + + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); + + summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + bx = _mm256_or_si256(bx, bxhi); + + const __m256 dy = _mm256_set1_ps(y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx, by); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8(0x10); + + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); + + summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = 
bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_and_si128(bxhil, mask); + bxhih = _mm_and_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx); + __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx = MM256_SET_M128I(bxh, bxl); + + const __m256 dy = _mm256_set1_ps(y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx, by); + + acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + uint32_t qh; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + // temporary registers for shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); + + // load qh + vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); + + // ((qh >> (j + 0)) << 4) & 0x10; + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); + + // ((qh >> (j + 12)) ) & 0x10; + vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); + + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); + + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; + + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q8_0 * restrict x = vx; + const block_q8_0 * 
restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + +#else + const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); + const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); + const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); + const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); + + const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); + const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); + const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); + const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); + + const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); + const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); + const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + // Multiply q with scale and accumulate +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d, q, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); +#endif + } + + *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + size_t vl = __riscv_vsetvl_e8m1(qk); + + for (int i = 0; i < nb; i++) { + // load elements + vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); + vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); + + vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + + int sumi = 
__riscv_vmv_x_s_i32m1_i32(v_sum); + + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[i].qs[j]*y[i].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + } + + *s = sumf; +#endif +} + +#if QK_K == 256 +void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m3 = vdupq_n_u8(0x3); + const uint8x16_t m4 = vdupq_n_u8(0xF); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + ggml_int8x16x2_t q2bytes; + uint8_t aux[16]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8_t * restrict sc = x[i].scales; + + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); + + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + + int isum = 0; + int is = 0; + +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#if defined(__ARM_FEATURE_DOTPROD) +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; +#else +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + {\ + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ + isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ + } +#endif + +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); + + + for (int j = 0; j < QK_K/128; ++j) { + + const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32; + + ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], 
m3)); + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m256i mins = _mm256_cvtepi8_epi16(mins8); + const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i q2_0 = _mm256_and_si256(q2bits, m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + + __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); + __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); + + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(0x3); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // load mins and scales from block_q2_K.scales[QK_K/16] + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins16 
= _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); + const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); + + // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 + const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); + const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); + + // sumf += -dmin * summs in 32bits*8 + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); + + const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); + const __m128i scales[2] = { scales_0, scales_1 }; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + + // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // load 2bits*16*8 from block_q2_K.qs[QK_K/4] + __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_1 = _mm_and_si128(q2bits, m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 + __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); + __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); + __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); + __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); + __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); + __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); + __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); + __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + + // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 + __m128i shuffle = _mm_set1_epi16(0x0100); + p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); + shuffle = _mm_add_epi16(shuffle, m2); + p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); + shuffle = _mm_add_epi16(shuffle, m2); + p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); + shuffle = _mm_add_epi16(shuffle, m2); + p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); + shuffle = _mm_add_epi16(shuffle, m2); + p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); + shuffle = _mm_add_epi16(shuffle, m2); + p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); + shuffle = _mm_add_epi16(shuffle, m2); + p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); + shuffle = _mm_add_epi16(shuffle, m2); + 
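// eighth and final scale-weighted partial product for this 128-quant chunk +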
p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + + p0 = _mm_add_epi32(p0, p1); + p2 = _mm_add_epi32(p2, p3); + p4 = _mm_add_epi32(p4, p5); + p6 = _mm_add_epi32(p6, p7); + + // isum in 32bits*4*2 + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); + } + + // sumf += dall * isum - dmin * summs in 32bits + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + size_t vl = 16; + + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + + vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + + uint8_t is=0; + int isum=0; + + for (int j = 0; j < QK_K/128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); + + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + 
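// chain two reductions so the scalar isum picks up all four widened product vectors +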
vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2+=32; q8+=128; is=8; + + } + + sumf += dall * isum; + + } + + *s = sumf; + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +#else + +void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m3 = vdupq_n_u8(0x3); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + ggml_int8x16x4_t q2bytes; + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const float dmin = -y[i].d * (float)x[i].dmin; + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + + sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + + int isum1 = 0, isum2 = 0; + + const uint8x16_t q2bits = vld1q_u8(q2); + + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); + + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); + q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); + q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); + +#if defined(__ARM_FEATURE_DOTPROD) + isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; + isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; +#else + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum1 += vaddvq_s16(p1) * scales[0]; + isum2 += vaddvq_s16(p2) * scales[1]; + + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + 
vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum1 += vaddvq_s16(p3) * scales[2]; + isum2 += vaddvq_s16(p4) * scales[3]; +#endif + sum += d * (isum1 + isum2); + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; + + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; + + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3); + const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + + const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); + const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); + const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); + const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; + + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; + + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + const __m256i q8_0 = _mm256_loadu_si256((const 
__m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); + const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); + const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); + const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const float dmin = -y[i].d * (float)x[i].dmin; + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + + sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + + int isum1 = 0; + int isum2 = 0; + + size_t vl = 16; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + // load Q2 + vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl); + + vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl)); + vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl)); + + // load Q8, and take product with Q2 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl); + vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl); + vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl); + + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1]; + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3]; + + sumf += d * (isum1 + isum2); + + } + + *s = sumf; + +#else + + float sumf = 0; + + int isum[4]; + + for (int i = 0; i < nb; 
++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < QK_K/16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + isum[0] = isum[1] = isum[2] = isum[3] = 0; + for (int l = 0; l < 16; ++l) { + isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); + isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); + isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); + isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); + } + for (int l = 0; l < 4; ++l) { + isum[l] *= (sc[l] & 0xF); + } + sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; + } + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + uint32_t aux[3]; + uint32_t utmp[4]; + + const uint8x16_t m3b = vdupq_n_u8(0x3); +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + ggml_int8x16x4_t q3bytes; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + for (int j = 0; j < QK_K/128; ++j) { + + const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32; + const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64; + const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64; + + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += 
vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; +#else + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } + + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & 
kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + const uint32_t *aux; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + aux = (const uint32_t *)x[i].scales; + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); + const __m128i scales[2] = { scales_0, scales_1 }; + + // high bit *128*2 from block_q3_K.hmask[QK_K/8] + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + + // integer accumulator + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] + const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + + // prepare low and high bits + const int bit = j << 2; + + const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); + const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); + const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); + const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); + const __m128i q3l_3 = 
_mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); + const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + + const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); + const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); + const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + + const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); + const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); + const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + + // load Q8 quants from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + // multiply with scales + __m128i shuffle = _mm_set1_epi16(0x0100); + p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); + shuffle = _mm_add_epi16(shuffle, m2); + p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); + shuffle = _mm_add_epi16(shuffle, m2); + p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); + shuffle = _mm_add_epi16(shuffle, m2); + p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); + shuffle = _mm_add_epi16(shuffle, m2); + p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); + shuffle = _mm_add_epi16(shuffle, m2); + p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); + shuffle = _mm_add_epi16(shuffle, m2); + p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); + shuffle = _mm_add_epi16(shuffle, m2); + p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); + + // accumulate + p16_0 = _mm_add_epi32(p16_0, p16_1); + p16_2 = _mm_add_epi32(p16_2, p16_3); + p16_4 = _mm_add_epi32(p16_4, p16_5); + p16_6 = _mm_add_epi32(p16_6, p16_7); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + + } + + // multiply with block scale and accumulate + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint32_t aux[3]; + uint32_t utmp[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + 
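+            // each of the two passes covers 128 quants: one 32-byte q3 load yields four groups of 32 2-bit values,
+            // the mask m walks the matching high bits in hmask (4 is subtracted where the bit is clear),
+            // and the q8 products are reduced against eight signed sub-block scales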
vl = 32; + + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + + *s = sumf; + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. 
+ // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +#else + +void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const uint8x16_t mh = vdupq_n_u8(4); + + ggml_int8x16x4_t q3bytes; + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + ggml_uint8x16x4_t q3h; + + const uint8x8_t hbits = vld1_u8(x[i].hmask); + const uint8x16_t q3bits = vld1q_u8(x[i].qs); + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs); + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * (float)x[i].d; + + const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); + q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q3h.val[1] = vandq_u8(mh, htmp); + q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); + q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); + + q3bytes.val[0] = 
vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); + q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); + q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); + q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; +#else + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3]; +#endif + + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m1 = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); + const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); + + memcpy(&aux64, x[i].hmask, 8); + + const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux); + __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); + q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); + q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits); + const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + // multiply with scales + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + p16_0 = _mm256_add_epi32(p16_0, p16_1); + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m1 = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); + const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); + const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); + const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); + + memcpy(&aux64, x[i].hmask, 8); + + __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); + __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); + __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); + q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); + q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); + q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); + q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m128i q3l_0 = _mm_and_si128(q3bits, m3); + const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, + // and then subtract. 
The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_1, p16_1); + p16_2 = _mm_madd_epi16(scale_2, p16_2); + p16_3 = _mm_madd_epi16(scale_3, p16_3); + + p16_0 = _mm_add_epi32(p16_0, p16_2); + p16_1 = _mm_add_epi32(p16_1, p16_3); + __m256i p16 = MM256_SET_M128I(p16_1, p16_0); + + // multiply with block scale and accumulate + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * (float)x[i].d; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load qh + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8); + vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + + size_t vl = 16; + + // extend and combine both qh_x1 and qh_x2 + vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl); + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl); + vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl); + vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl); + vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3); + + // load Q8 and take product with Q3 + vint16m1_t p0 = 
__riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3]; + + sumf += d * isum; + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + int32_t scales[4]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 8; ++l) { + a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); + a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); + a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); + a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); + a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); + a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); + a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); + a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4); + } + + scales[0] = (x[i].scales[0] & 0xF) - 8; + scales[1] = (x[i].scales[0] >> 4) - 8; + scales[2] = (x[i].scales[1] & 0xF) - 8; + scales[3] = (x[i].scales[1] >> 4) - 8; + + memset(aux32, 0, 8*sizeof(int32_t)); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + ggml_int8x16x2_t q4bytes; + ggml_int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + 
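+        // the 12 scale bytes of a q4_K super-block pack eight 6-bit scales and eight 6-bit mins;
+        // mins8 now holds the mins, and the next two statements leave the eight scales in utmp[0..1]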
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + +#ifdef __ARM_FEATURE_DOTPROD + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; +#else + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; + +#endif + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), 
_mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + p16l = _mm256_madd_epi16(scale_l, p16l); + + const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + p16h = _mm256_madd_epi16(scale_h, p16h); + const __m256i sumj = _mm256_add_epi32(p16l, p16h); + + sumi = _mm256_add_epi32(sumi, sumj); + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_0 = _mm_and_si128(q4bits, m4); + const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_1 = _mm_and_si128(q4bits, m4); + const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + + const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16l = 
_mm_maddubs_epi16(q4l_0, q8l_0); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_0 = _mm_add_epi32(sumi_0, p16l); + const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16l = _mm_maddubs_epi16(q4l_1, q8l_1); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_1 = _mm_add_epi32(sumi_1, p16l); + + const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_0 = _mm_add_epi32(sumi_0, p16h); + const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16h = _mm_maddubs_epi16(q4h_1, q8h_1); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_1 = _mm_add_epi32(sumi_1, p16h); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + size_t vl = 8; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + vl = 32; + + int32_t sum_1 = 0; + int32_t sum_2 = 0; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; + + } + + sumf += d*(sum_1 + sum_2); + + } + + *s = sumf; + +#else + + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + 
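+    // scalar reference path: expand the 4-bit quants into aux8, accumulate q8*q4 per 32-quant sub-block
+    // weighted by its 6-bit scale, then subtract dmin times the bsums-weighted mins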
memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#else +void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + float sumf = 0; + + ggml_int8x16x2_t q4bytes; + ggml_int8x16x4_t q8bytes; + + float sum_mins = 0.f; + + uint16_t aux16[2]; + const uint8_t * restrict scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); + sum_mins += y[i].d * (float)x[i].d[1] * summi; + + const float d = y[i].d * (float)x[i].d[0]; + + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); + +#ifdef __ARM_FEATURE_DOTPROD + q8bytes = ggml_vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); + const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; + +#else + q8bytes = ggml_vld1q_s8_x4(q8); + q4bytes.val[0] = 
vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3]))); + int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1]; + +#endif + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf - sum_mins; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; + const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + + const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); + + const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; + const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); + const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); + const __m128i q4_0 = 
_mm_and_si128(q4bits_0, m4); + const __m128i q4_1 = _mm_and_si128(q4bits_1, m4); + const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); + const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); + const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc); + + const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); + const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#elif defined __riscv_v_intrinsic + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + + size_t vl = 32; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl); + + sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1); + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl); + + sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2); + + } + + *s = sumf; + +#else + + uint8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; + for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); + + for (int j = 0; j < QK_K/32; ++j) { + for (int l = 0; l < 16; ++l) aux16[l] = 
q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + const float dl = d * scales[j]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + ggml_int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32; + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), 
vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; +#endif + } + + sumf += d * sumi - dmin * sumi_mins; + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + +#if QK_K == 256 + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; +#else + // TODO + const float d = 0, dmin = 0; +#endif + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); + __m256i hmask = mone; + + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + 
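+    // acc holds the per-super-block d * sum(scale * q5 * q8) terms; summs already carries the
+    // -dmin * sum(min * bsum) correction, so one horizontal add finishes the dot product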
*s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); + __m128i hmask = mone; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int bit = 0; + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + + __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); + __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); + __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); + __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_0, p16_1); + + q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); + q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); + q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + q5_0 = _mm_add_epi8(q5l_0, q5h_0); + q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); + p16_2 = _mm_madd_epi16(scale_1, p16_2); + p16_3 = _mm_madd_epi16(scale_1, p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = 
_mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + float sums = 0.0; + + size_t vl; + + for (int i = 0; i < nb; ++i) { + + vl = 8; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + vl = 32; + int32_t aux32 = 0; + int is = 0; + + uint8_t m = 1; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q5 and Q8 + vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); + vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + + // compute mask for addition + vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl); + m <<= 1; + + vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl); + m <<= 1; + + vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); + vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + + vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); + vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + q5 += 32; q8 += 64; + + } + + vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); + sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + + } + + *s = sumf+sums; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const 
uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#else + +void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mh = vdupq_n_u8(16); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + ggml_int8x16x4_t q5bytes; + ggml_uint8x16x4_t q5h; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const int8_t * sc = x[i].scales; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const uint8x8_t qhbits = vld1_u8(qh); + + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); + + const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); + q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vbicq_u8(mh, htmp); + q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); + + q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); + q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); + q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); + q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); + int32_t sumi2 = sc[1] * 
vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); + int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); + int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); + + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1); + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3); + + sumf += d*sumi; +#endif + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + + const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); + const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); + + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128); + + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); + const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); + const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); + const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); + + const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mone = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + + const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); + const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); + const __m128i scale_2 = 
_mm_set1_epi16(x[i].scales[2]); + const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); + + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); + + const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); + const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4); + const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4); + const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4); + + const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4); + const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4); + const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4); + const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1))); + const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1))); + + const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); + const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const int8_t * sc = x[i].scales; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load qh + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8); + vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + + size_t vl = 16; + + // combine both qh_1 and qh_2 + vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + + vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl); + vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl); + vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + + vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0); + vint8mf2_t qh_1 = 
__riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1); + vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2); + vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3); + + // load q5 + vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl); + vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl); + + vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl)); + vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl)); + vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl)); + vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl)); + + vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl); + vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl); + vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl); + vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl); + + // load Q8 and multiply it with Q5 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0); + int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1); + int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2); + int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3); + + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) { + a[l+ 0] = q4[l] & 0xF; + a[l+32] = q4[l] >> 4; + } + for (int is = 0; is < 8; ++is) { + uint8_t m = 1 << is; + for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 
0 : 16); + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const int8_t * restrict sc = x[i].scales; + + for (int j = 0; j < QK_K/16; ++j) { + const float dl = d * sc[j]; + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); + q8 += 16; a += 16; + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + + +#if QK_K == 256 +void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + //const int8x16_t m32s = vdupq_n_s8(32); + + const uint8x16_t mone = vdupq_n_u8(3); + + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; + + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); + + int32_t isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; + ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + 
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + +#else + + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + + //for (int l = 0; l < 4; ++l) { + // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); + // isum += vaddvq_s32(p) * *scale++; + //} +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + } + //sum += isum * 
d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); + const __m128i m2 = 
_mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); + const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); + const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); + const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); + const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); + + const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); + const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); + const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); + const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); + const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); + + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); + __m128i p16_4 = 
_mm_maddubs_epi16(q4_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); + p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); + p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); + p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); + + } + + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + size_t vl; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = 
__riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#else + +void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict 
vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int8x16_t m32s = vdupq_n_s8(32); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t mone = vdupq_n_u8(3); + + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = (float)x[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int32_t isum = 0; + + uint8x16_t qhbits = vld1q_u8(qh); + ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6); + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits, 2); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 4); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); + q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; +#else + + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + + sum += isum * d_all * y[i].d; + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + + __m256i sumi = 
_mm256_setzero_si256(); + + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); + __m128i q8s_2 = 
_mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d_all = (float)x[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int32_t isum = 0; + + size_t vl = 16; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl); + + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); + + vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + + vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl); + vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl); + vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl); + vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl); + + vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl); + vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl); + vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl); + vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl); + + // load Q8 and take product + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + 
vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3]; + + sumf += isum * d_all * y[i].d; + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int l = 0; l < 16; ++l) { + a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#endif diff --git a/applications/src/llama/ggml-quants.h b/applications/src/llama/ggml-quants.h new file mode 100644 index 000000000..70c12c274 --- /dev/null +++ b/applications/src/llama/ggml-quants.h @@ -0,0 +1,224 @@ +#pragma once + +#include "ggml-impl.h" + +// GGML internal header + +#include +#include + +#define QK4_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + ggml_fp16_t d; // delta + ggml_fp16_t m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + ggml_fp16_t d; // delta + ggml_fp16_t m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +typedef struct { + ggml_fp16_t d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +typedef struct { + float d; // delta + float s; // d * sum(qs[i]) + int8_t qs[QK8_1]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); + +// +// Super-block quantization structures +// + +// Super-block size 
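+// QK_K is the number of weights stored in one super-block. K_SCALE_SIZE is the
+// number of bytes reserved for the quantized sub-block scales/mins in the
+// 256-weight q4_K and q5_K blocks below.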
+#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else +#define QK_K 256 +#define K_SCALE_SIZE 12 +#endif + +// 2-bit quantization +// weight is represented as x = a * q + b +// 16 blocks of 16 elements each +// Effectively 2.5625 bits per weight +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +// 3-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 3.4375 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[2]; + ggml_fp16_t d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); +#else +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[12]; // scales, quantized with 6 bits + ggml_fp16_t d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); +#endif + +// 4-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_fp16_t d[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else +typedef struct { + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); +#endif + +// 5-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 5.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_fp16_t d; // super-block scale + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + ggml_fp16_t d; // super-block scale +} 
block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); + +// This is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_K; +static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); + + +// Quantization +void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k); +void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k); +void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k); +void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k); +void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k); +void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k); + +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k); +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k); +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k); +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k); +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); + +void quantize_row_q4_0(const float * restrict x, void * restrict y, int k); +void quantize_row_q4_1(const float * restrict x, void * restrict y, int k); +void quantize_row_q5_0(const float * restrict x, void * restrict y, int k); +void quantize_row_q5_1(const float * restrict x, void * restrict y, int k); +void quantize_row_q8_0(const float * restrict x, void * restrict y, int k); +void quantize_row_q8_1(const float * restrict x, void * restrict y, int k); + +void quantize_row_q2_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q3_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q4_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); + +// Dequantization +void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k); +void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k); +void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k); +void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k); +void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k); +//void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k); + +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k); +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k); +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k); +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); +void dequantize_row_q8_K(const block_q8_K * restrict x, float * 
restrict y, int k); + +// Dot product +void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); + +void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); diff --git a/applications/src/llama/ggml.c b/applications/src/llama/ggml.c new file mode 100644 index 000000000..ad546a731 --- /dev/null +++ b/applications/src/llama/ggml.c @@ -0,0 +1,19762 @@ +#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows +#define _USE_MATH_DEFINES // For M_PI on MSVC + +#include "ggml-impl.h" +#include "ggml-quants.h" + +#if defined(_MSC_VER) || defined(__MINGW32__) +#include // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef GGML_USE_METAL +#include +#endif + +#if defined(_MSC_VER) +// disable "possible loss of data" to avoid hundreds of casts +// we should just be careful :) +#pragma warning(disable: 4244 4267) + +// disable POSIX deprecation warnings +// these functions are never going away, anyway +#pragma warning(disable: 4996) +#endif + +#if defined(_WIN32) + +#include + +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; + +static void atomic_store(atomic_int * ptr, LONG val) { + InterlockedExchange(ptr, val); +} +static LONG atomic_load(atomic_int * ptr) { + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { + return InterlockedExchangeAdd(ptr, inc); +} +static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) { + return atomic_fetch_add(ptr, -(dec)); +} + +typedef HANDLE pthread_t; + +typedef DWORD thread_ret_t; +static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) { + (void) unused; + HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); + if (handle == NULL) + { + return EAGAIN; + } + + *out = handle; + return 0; +} + +static int pthread_join(pthread_t thread, void * unused) { + (void) unused; + int ret = (int) WaitForSingleObject(thread, INFINITE); + CloseHandle(thread); + return ret; +} + +static int sched_yield (void) { + Sleep (0); + return 0; +} +#else +#include +#include + +typedef void * thread_ret_t; + +#include +#include +#include + +#endif + +#ifdef GGML_USE_CPU_HBM +#include +#endif + +#if defined(__APPLE__) +#include +#endif + +#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || 
defined(__OpenBSD__)) && \ + (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH)) + +#include + +void ggml_print_backtrace(void) { + /* + #include + #include + + void * trace[100]; + + int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); + + backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); + */ + + // backtrack_symbols does not show line numbers, use gdb instead + char attach[32]; + snprintf(attach, sizeof(attach), "attach %d", getpid()); + int pid = fork(); + if (pid == 0) { + execlp("gdb", "gdb", "--batch", + "-ex", "set style enabled on", + "-ex", attach, + "-ex", "bt -frame-info source-and-location", + "-ex", "detach", + "-ex", "quit", + NULL); + } else { + waitpid(pid, NULL, 0); + } +} +#else +void ggml_print_backtrace(void) { + // platform not supported +} +#endif + +/*#define GGML_PERF*/ +#define GGML_DEBUG 0 +#define GGML_GELU_FP16 +#define GGML_GELU_QUICK_FP16 +#define GGML_SILU_FP16 +// #define GGML_CROSS_ENTROPY_EXP_FP16 +// #define GGML_FLASH_ATTN_EXP_FP16 + +#define GGML_SOFT_MAX_UNROLL 4 +#define GGML_VEC_DOT_UNROLL 2 +#define GGML_VEC_MAD_UNROLL 32 + +// +// logging +// + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +#define GGML_PRINT(...) printf(__VA_ARGS__) + +// +// end of logging block +// + +#ifdef GGML_USE_ACCELERATE +// uncomment to use vDSP for soft max computation +// note: not sure if it is actually faster +//#define GGML_SOFT_MAX_ACCELERATE +#endif + +#if defined(_MSC_VER) || defined(__MINGW32__) +#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN) +#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) +#else +inline static void * ggml_aligned_malloc(size_t size) { + if (size == 0) { + GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n"); + return NULL; + } + void * aligned_memory = NULL; +#ifdef GGML_USE_CPU_HBM + int result = hbw_posix_memalign(&aligned_memory, 16, size); +#elif GGML_USE_METAL + int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size); +#else + int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size); +#endif + if (result != 0) { + // Handle allocation failure + const char *error_desc = "unknown allocation error"; + switch (result) { + case EINVAL: + error_desc = "invalid alignment value"; + break; + case ENOMEM: + error_desc = "insufficient memory"; + break; + } + GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); + return NULL; + } + return aligned_memory; +} +#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size) +#ifdef GGML_USE_CPU_HBM +#define GGML_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr) +#else +#define GGML_ALIGNED_FREE(ptr) free(ptr) +#endif +#endif + +#define UNUSED GGML_UNUSED +#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) + +#if defined(GGML_USE_ACCELERATE) +#include +#if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions +#include "ggml-opencl.h" +#endif +#elif defined(GGML_USE_OPENBLAS) +#if defined(GGML_BLAS_USE_MKL) +#include +#else +#include +#endif +#elif defined(GGML_USE_CUBLAS) +#include "ggml-cuda.h" +#elif defined(GGML_USE_CLBLAST) +#include "ggml-opencl.h" +#endif + +// floating 
point type used to accumulate sums +typedef double ggml_float; + +#undef MIN +#undef MAX + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// +// global data +// + +// precomputed gelu table for f16 (128 KB) +static ggml_fp16_t ggml_table_gelu_f16[1 << 16]; + +// precomputed quick gelu table for f16 (128 KB) +static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16]; + +// precomputed silu table for f16 (128 KB) +static ggml_fp16_t ggml_table_silu_f16[1 << 16]; + +// precomputed exp table for f16 (128 KB) +static ggml_fp16_t ggml_table_exp_f16[1 << 16]; + +// precomputed f32 table for f16 (256 KB) (ggml-impl.h) +float ggml_table_f32_f16[1 << 16]; + +// note: do not use these inside ggml.c +// these are meant to be used via the ggml.h API +float ggml_fp16_to_fp32(ggml_fp16_t x) { + return (float) GGML_FP16_TO_FP32(x); +} + +ggml_fp16_t ggml_fp32_to_fp16(float x) { + return GGML_FP32_TO_FP16(x); +} + +void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n) { + for (int i = 0; i < n; i++) { + y[i] = GGML_FP16_TO_FP32(x[i]); + } +} + +void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) { + int i = 0; +#if defined(__F16C__) + for (; i + 7 < n; i += 8) { + __m256 x_vec = _mm256_loadu_ps(x + i); + __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128((__m128i *)(y + i), y_vec); + } + for(; i + 3 < n; i += 4) { + __m128 x_vec = _mm_loadu_ps(x + i); + __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storel_epi64((__m128i *)(y + i), y_vec); + } +#endif + for (; i < n; i++) { + y[i] = GGML_FP32_TO_FP16(x[i]); + } +} + +// +// timing +// + +#if defined(_MSC_VER) || defined(__MINGW32__) +static int64_t timer_freq, timer_start; +void ggml_time_init(void) { + LARGE_INTEGER t; + QueryPerformanceFrequency(&t); + timer_freq = t.QuadPart; + + // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq + // and the uptime is high enough. + // We subtract the program start time to reduce the likelihood of that happening. 
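    // Editor's note (worked example, assuming a typical QueryPerformanceFrequency
    // of 10 MHz): without the subtraction, t.QuadPart * 1000000 in ggml_time_us()
    // would exceed INT64_MAX (~9.2e18) once the raw counter passes ~9.2e12 ticks,
    // i.e. after roughly 10-11 days of uptime; measuring relative to timer_start
    // keeps the multiplied value small for the lifetime of the process.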
+ QueryPerformanceCounter(&t); + timer_start = t.QuadPart; +} +int64_t ggml_time_ms(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart-timer_start) * 1000) / timer_freq; +} +int64_t ggml_time_us(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart-timer_start) * 1000000) / timer_freq; +} +#else +void ggml_time_init(void) {} +int64_t ggml_time_ms(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; +} + +int64_t ggml_time_us(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; +} +#endif + +int64_t ggml_cycles(void) { + return clock(); +} + +int64_t ggml_cycles_per_ms(void) { + return CLOCKS_PER_SEC/1000; +} + +#ifdef GGML_PERF +#define ggml_perf_time_ms() ggml_time_ms() +#define ggml_perf_time_us() ggml_time_us() +#define ggml_perf_cycles() ggml_cycles() +#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms() +#else +#define ggml_perf_time_ms() 0 +#define ggml_perf_time_us() 0 +#define ggml_perf_cycles() 0 +#define ggml_perf_cycles_per_ms() 0 +#endif + +// +// cache line +// + +#if defined(__cpp_lib_hardware_interference_size) +#define CACHE_LINE_SIZE hardware_destructive_interference_size +#else +#if defined(__POWER9_VECTOR__) +#define CACHE_LINE_SIZE 128 +#else +#define CACHE_LINE_SIZE 64 +#endif +#endif + +static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); + +static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); +static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); + +static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { + [GGML_TYPE_I8] = { + .type_name = "i8", + .blck_size = 1, + .type_size = sizeof(int8_t), + .is_quantized = false, + }, + [GGML_TYPE_I16] = { + .type_name = "i16", + .blck_size = 1, + .type_size = sizeof(int16_t), + .is_quantized = false, + }, + [GGML_TYPE_I32] = { + .type_name = "i32", + .blck_size = 1, + .type_size = sizeof(int32_t), + .is_quantized = false, + }, + [GGML_TYPE_F32] = { + .type_name = "f32", + .blck_size = 1, + .type_size = sizeof(float), + .is_quantized = false, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + }, + [GGML_TYPE_F16] = { + .type_name = "f16", + .blck_size = 1, + .type_size = sizeof(ggml_fp16_t), + .is_quantized = false, + .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, + .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, + .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, + .vec_dot_type = GGML_TYPE_F16, + }, + [GGML_TYPE_Q4_0] = { + .type_name = "q4_0", + .blck_size = QK4_0, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_0, + .from_float = quantize_row_q4_0, + .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, + .vec_dot = ggml_vec_dot_q4_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, + [GGML_TYPE_Q4_1] = { + .type_name = "q4_1", + .blck_size = QK4_1, + .type_size = sizeof(block_q4_1), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_1, + .from_float = quantize_row_q4_1, + .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, + .vec_dot = ggml_vec_dot_q4_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, + }, + [4] = { // GGML_TYPE_Q4_2 + 
.type_name = "DEPRECATED", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + .to_float = NULL, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_COUNT, + }, + [5] = { // GGML_TYPE_Q4_3 + .type_name = "DEPRECATED", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + .to_float = NULL, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_COUNT, + }, + [GGML_TYPE_Q5_0] = { + .type_name = "q5_0", + .blck_size = QK5_0, + .type_size = sizeof(block_q5_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_0, + .from_float = quantize_row_q5_0, + .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, + .vec_dot = ggml_vec_dot_q5_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, + [GGML_TYPE_Q5_1] = { + .type_name = "q5_1", + .blck_size = QK5_1, + .type_size = sizeof(block_q5_1), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_1, + .from_float = quantize_row_q5_1, + .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, + .vec_dot = ggml_vec_dot_q5_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, + }, + [GGML_TYPE_Q8_0] = { + .type_name = "q8_0", + .blck_size = QK8_0, + .type_size = sizeof(block_q8_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_0, + .from_float = quantize_row_q8_0, + .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, + .vec_dot = ggml_vec_dot_q8_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, + [GGML_TYPE_Q8_1] = { + .type_name = "q8_1", + .blck_size = QK8_1, + .type_size = sizeof(block_q8_1), + .is_quantized = true, + .from_float = quantize_row_q8_1, + .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, + .vec_dot_type = GGML_TYPE_Q8_1, + }, + [GGML_TYPE_Q2_K] = { + .type_name = "q2_K", + .blck_size = QK_K, + .type_size = sizeof(block_q2_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q2_K, + .from_float = quantize_row_q2_K, + .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, + .vec_dot = ggml_vec_dot_q2_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q3_K] = { + .type_name = "q3_K", + .blck_size = QK_K, + .type_size = sizeof(block_q3_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q3_K, + .from_float = quantize_row_q3_K, + .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, + .vec_dot = ggml_vec_dot_q3_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q4_K] = { + .type_name = "q4_K", + .blck_size = QK_K, + .type_size = sizeof(block_q4_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_K, + .from_float = quantize_row_q4_K, + .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, + .vec_dot = ggml_vec_dot_q4_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q5_K] = { + .type_name = "q5_K", + .blck_size = QK_K, + .type_size = sizeof(block_q5_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_K, + .from_float = quantize_row_q5_K, + .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, + .vec_dot = ggml_vec_dot_q5_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q6_K] = { + .type_name = "q6_K", + .blck_size = QK_K, + .type_size = sizeof(block_q6_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q6_K, + .from_float = quantize_row_q6_K, + .from_float_reference = 
(ggml_from_float_t) quantize_row_q6_K_reference, + .vec_dot = ggml_vec_dot_q6_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q8_K] = { + .type_name = "q8_K", + .blck_size = QK_K, + .type_size = sizeof(block_q8_K), + .is_quantized = true, + .from_float = quantize_row_q8_K, + } +}; + +// For internal test use +ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { + GGML_ASSERT(type < GGML_TYPE_COUNT); + return type_traits[type]; +} + +// +// simd mappings +// + +#if defined(__ARM_NEON) +#if !defined(__aarch64__) + +// 64-bit compatibility + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +#endif +#endif + +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// GGML_F32_STEP / GGML_F16_STEP +// number of elements to process in a single step +// +// GGML_F32_EPR / GGML_F16_EPR +// number of elements to fit in a single register +// + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + +#define GGML_SIMD + +// F32 NEON + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 float32x4_t +#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define GGML_F32x4_LOAD vld1q_f32 +#define GGML_F32x4_STORE vst1q_f32 +#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define GGML_F32x4_ADD vaddq_f32 +#define GGML_F32x4_MUL vmulq_f32 +#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ + } \ + res = GGML_F32x4_REDUCE_ONE(x[0]); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD vld1q_f16 + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + do { \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ + res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ + } while (0) + + #define GGML_F16_VEC 
GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif + +#elif defined(__AVX__) + +#define GGML_SIMD + +// F32 AVX + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO _mm256_setzero_ps() +#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32x8_LOAD _mm256_loadu_ps +#define GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define GGML_F32x8_ADD _mm256_add_ps +#define GGML_F32x8_MUL _mm256_mul_ps +#define GGML_F32x8_REDUCE(res, x) \ +do { \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} while (0) +// TODO: is this optimal ? 
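// Editor's note (illustrative sketch, not upstream code): on AVX the dot-product
// kernels keep GGML_F32_ARR = GGML_F32_STEP / GGML_F32_EPR = 32 / 8 = 4 partial-sum
// registers. GGML_F32x8_REDUCE above folds them pairwise (4 -> 2 -> 1) and then
// horizontally adds the 8 lanes of x[0]. A scalar equivalent would be:
//
//     float reduce(const float x[4][8]) {
//         float res = 0.0f;
//         for (int i = 0; i < 4; ++i)
//             for (int l = 0; l < 8; ++l)
//                 res += x[i][l];   // sum of all 32 partial products
//         return res;
//     }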
+ +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 AVX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) + +#if defined(__F16C__) +// the _mm256_cvt intrinsics require F16C +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#else +static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { + float arr[8]; + + _mm256_storeu_ps(arr, y); + + for (int i = 0; i < 8; i++) + x[i] = GGML_FP32_TO_FP16(arr[i]); +} +#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) +#endif + +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD _mm256_add_ps +#define GGML_F32Cx8_MUL _mm256_mul_ps +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE + +#elif defined(__POWER9_VECTOR__) + +#define GGML_SIMD + +// F32 POWER9 + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 vector float +#define GGML_F32x4_ZERO 0.0f +#define GGML_F32x4_SET1 vec_splats +#define GGML_F32x4_LOAD(p) vec_xl(0, p) +#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) +#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define GGML_F32x4_ADD vec_add +#define GGML_F32x4_MUL vec_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 POWER9 +#define GGML_F16_STEP GGML_F32_STEP +#define GGML_F16_EPR GGML_F32_EPR +#define GGML_F16_VEC GGML_F32x4 +#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F16_VEC_FMA GGML_F32x4_FMA +#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +// Use vec_xl, not 
vec_ld, in case the load address is not aligned. +#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ + vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ + vec_extract_fp32_from_shortl(vec_xl(0, p)) +#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +#define GGML_F16_VEC_STORE(p, r, i) \ + if (i & 0x1) \ + vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ + r[i - GGML_ENDIAN_BYTE(0)]), \ + 0, p - GGML_F16_EPR) + +#elif defined(__wasm_simd128__) + +#define GGML_SIMD + +// F32 WASM + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 v128_t +#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F32x4_LOAD wasm_v128_load +#define GGML_F32x4_STORE wasm_v128_store +#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define GGML_F32x4_ADD wasm_f32x4_add +#define GGML_F32x4_MUL wasm_f32x4_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 WASM + +#define GGML_F16_STEP 16 +#define GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(p[0]); + tmp[1] = GGML_FP16_TO_FP32(p[1]); + tmp[2] = GGML_FP16_TO_FP32(p[2]); + tmp[3] = GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = GGML_FP32_TO_FP16(tmp[0]); + p[1] = GGML_FP32_TO_FP16(tmp[1]); + p[2] = GGML_FP32_TO_FP16(tmp[2]); + p[3] = GGML_FP32_TO_FP16(tmp[3]); +} + +#define GGML_F16x4 v128_t +#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define GGML_F16x4_FMA GGML_F32x4_FMA +#define GGML_F16x4_ADD wasm_f32x4_add +#define GGML_F16x4_MUL wasm_f32x4_mul +#define GGML_F16x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F16_VEC GGML_F16x4 +#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p) +#define GGML_F16_VEC_STORE(p, 
r, i) GGML_F16x4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F16x4_FMA +#define GGML_F16_VEC_ADD GGML_F16x4_ADD +#define GGML_F16_VEC_MUL GGML_F16x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE + +#elif defined(__SSE3__) + +#define GGML_SIMD + +// F32 SSE + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 __m128 +#define GGML_F32x4_ZERO _mm_setzero_ps() +#define GGML_F32x4_SET1(x) _mm_set1_ps(x) +#define GGML_F32x4_LOAD _mm_loadu_ps +#define GGML_F32x4_STORE _mm_storeu_ps +#if defined(__FMA__) + // TODO: Does this work? + #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) +#else + #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) +#endif +#define GGML_F32x4_ADD _mm_add_ps +#define GGML_F32x4_MUL _mm_mul_ps +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ +} +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 SSE + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 4 + +static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(x[0]); + tmp[1] = GGML_FP16_TO_FP32(x[1]); + tmp[2] = GGML_FP16_TO_FP32(x[2]); + tmp[3] = GGML_FP16_TO_FP32(x[3]); + + return _mm_loadu_ps(tmp); +} + +static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { + float arr[4]; + + _mm_storeu_ps(arr, y); + + x[0] = GGML_FP32_TO_FP16(arr[0]); + x[1] = GGML_FP32_TO_FP16(arr[1]); + x[2] = GGML_FP32_TO_FP16(arr[2]); + x[3] = GGML_FP32_TO_FP16(arr[3]); +} + +#define GGML_F32Cx4 __m128 +#define GGML_F32Cx4_ZERO _mm_setzero_ps() +#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x) +#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) +#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) +#define GGML_F32Cx4_FMA GGML_F32x4_FMA +#define GGML_F32Cx4_ADD _mm_add_ps +#define GGML_F32Cx4_MUL _mm_mul_ps +#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + +#define GGML_F16_VEC GGML_F32Cx4 +#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE + +#endif + +// GGML_F32_ARR / GGML_F16_ARR +// number of registers to use per step +#ifdef GGML_SIMD +#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) +#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) +#endif + +// +// fundamental operations +// + +inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void 
ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } + +static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { +#ifdef GGML_SIMD + float sumf = 0.0f; + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } +#else + // scalar + ggml_float sumf = 0.0; + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(x[i]*y[i]); + } +#endif + + *s = sumf; +} + +static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { + ggml_float sumf = 0.0; + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); + } +#else + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); + } +#endif + + *s = sumf; +} + +// compute GGML_VEC_DOT_UNROLL dot products at once +// xs - x row stride in bytes +inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, 
float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { + ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; + + ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + + sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } + } + } + + // reduce sum0..sum3 to sum0 + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); + } + } +#else + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); + } + } +#endif + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + s[i] = sumf[i]; + } +} + +inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += x[i]*v; + } +#endif +} + +// xs and vs are byte strides of x and v +inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { + + const float * restrict x[GGML_VEC_MAD_UNROLL]; + const float * restrict v[GGML_VEC_MAD_UNROLL]; + + for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) { + x[i] = (const float *) ((const char *) xv + i*xs); + v[i] = (const float *) ((const char *) vv + i*vs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + vx[k] = GGML_F32_VEC_SET1(v[k][0]); + } + + GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + } + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = np; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#else + // scalar + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = 0; i < n; ++i) { + y[i] += 
x[k][i]*v[k][0]; + } + } +#endif +} + +//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } +inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { +#if defined(GGML_USE_ACCELERATE) + vDSP_vsmul(y, 1, &v, y, 1, n); +#elif defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] *= v; + } +#endif +} + +inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } +inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } +inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } +inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } +inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } +inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } +inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } +inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } +inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? 
x[i] : 0.f); } + +static const float GELU_COEF_A = 0.044715f; +static const float GELU_QUICK_COEF = -1.702f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +inline static float ggml_gelu_f32(float x) { + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = ggml_table_gelu_f16[i16[i]]; + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]); + } +} +#else +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]); + } +} +#endif + +inline static float ggml_gelu_quick_f32(float x) { + return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); +} + +//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { +// const uint16_t * i16 = (const uint16_t *) x; +// for (int i = 0; i < n; ++i) { +// y[i] = ggml_table_gelu_quick_f16[i16[i]]; +// } +//} + +#ifdef GGML_GELU_QUICK_FP16 +inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]); + } +} +#else +inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_quick_f32(x[i]); + } +} +#endif + +// Sigmoid Linear Unit (SiLU) function +inline static float ggml_silu_f32(float x) { + return x/(1.0f + expf(-x)); +} + +//inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { +// const uint16_t * i16 = (const uint16_t *) x; +// for (int i = 0; i < n; ++i) { +// y[i] = ggml_table_silu_f16[i16[i]]; +// } +//} + +#ifdef GGML_SILU_FP16 +inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]); + } +} +#else +inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]); + } +} +#endif + +inline static float ggml_silu_backward_f32(float x, float dy) { + const float s = 1.0f/(1.0f + expf(-x)); + return dy*s*(1.0f + x*(1.0f - s)); +} + +#ifdef GGML_SILU_FP16 +inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { + for (int i = 0; i < n; ++i) { + // we did not use x[i] to compute forward silu but its f16 equivalent + // take derivative at f16 of x[i]: + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + float usedx = GGML_FP16_TO_FP32(fp16); + dx[i] = ggml_silu_backward_f32(usedx, dy[i]); + } +} +#else +inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { + for (int i = 0; i < n; ++i) { + dx[i] = ggml_silu_backward_f32(x[i], dy[i]); + } +} +#endif + +inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + ggml_float sum = 0.0; 
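    // Editor's note: the sum is accumulated in ggml_float (a typedef for double,
    // defined above) rather than float, so adding many small float terms does not
    // lose precision before the final narrowing store into *s.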
+ for (int i = 0; i < n; ++i) { + sum += (ggml_float)x[i]; + } + *s = sum; +#else + vDSP_sve(x, 1, s, n); +#endif +} + +inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (ggml_float)x[i]; + } + *s = sum; +} + +inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += GGML_FP16_TO_FP32(x[i]); + } + *s = sum; +} + +inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + float max = -INFINITY; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + *s = max; +#else + vDSP_maxv(x, 1, s, n); +#endif +} + +inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { + ggml_vec_norm_f32(n, s, x); + *s = 1.f/(*s); +} + +inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { + float max = -INFINITY; + int idx = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + if (max == x[i]) { idx = i; } + } + *s = idx; +} + +// +// data types +// + +static const char * GGML_OP_NAME[GGML_OP_COUNT] = { + "NONE", + + "DUP", + "ADD", + "ADD1", + "ACC", + "SUB", + "MUL", + "DIV", + "SQR", + "SQRT", + "LOG", + "SUM", + "SUM_ROWS", + "MEAN", + "ARGMAX", + "REPEAT", + "REPEAT_BACK", + "CONCAT", + "SILU_BACK", + "NORM", + "RMS_NORM", + "RMS_NORM_BACK", + "GROUP_NORM", + + "MUL_MAT", + "MUL_MAT_ID", + "OUT_PROD", + + "SCALE", + "SET", + "CPY", + "CONT", + "RESHAPE", + "VIEW", + "PERMUTE", + "TRANSPOSE", + "GET_ROWS", + "GET_ROWS_BACK", + "DIAG", + "DIAG_MASK_INF", + "DIAG_MASK_ZERO", + "SOFT_MAX", + "SOFT_MAX_BACK", + "ROPE", + "ROPE_BACK", + "ALIBI", + "CLAMP", + "CONV_TRANSPOSE_1D", + "IM2COL", + "CONV_TRANSPOSE_2D", + "POOL_1D", + "POOL_2D", + "UPSCALE", + "PAD", + "ARGSORT", + "LEAKY_RELU", + + "FLASH_ATTN", + "FLASH_FF", + "FLASH_ATTN_BACK", + "WIN_PART", + "WIN_UNPART", + "GET_REL_POS", + "ADD_REL_POS", + + "UNARY", + + "MAP_UNARY", + "MAP_BINARY", + + "MAP_CUSTOM1_F32", + "MAP_CUSTOM2_F32", + "MAP_CUSTOM3_F32", + + "MAP_CUSTOM1", + "MAP_CUSTOM2", + "MAP_CUSTOM3", + + "CROSS_ENTROPY_LOSS", + "CROSS_ENTROPY_LOSS_BACK", +}; + +static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); + +static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { + "none", + + "x", + "x+y", + "x+y", + "view(x,nb,offset)+=y->x", + "x-y", + "x*y", + "x/y", + "x^2", + "√x", + "log(x)", + "Σx", + "Σx_k", + "Σx/n", + "argmax(x)", + "repeat(x)", + "repeat_back(x)", + "concat(x, y)", + "silu_back(x)", + "norm(x)", + "rms_norm(x)", + "rms_norm_back(x)", + "group_norm(x)", + + "X*Y", + "X[i]*Y", + "X*Y", + + "x*v", + "y-\\>view(x)", + "x-\\>y", + "cont(x)", + "reshape(x)", + "view(x)", + "permute(x)", + "transpose(x)", + "get_rows(x)", + "get_rows_back(x)", + "diag(x)", + "diag_mask_inf(x)", + "diag_mask_zero(x)", + "soft_max(x)", + "soft_max_back(x)", + "rope(x)", + "rope_back(x)", + "alibi(x)", + "clamp(x)", + "conv_transpose_1d(x)", + "im2col(x)", + "conv_transpose_2d(x)", + "pool_1d(x)", + "pool_2d(x)", + "upscale(x)", + "pad(x)", + "argsort(x)", + "leaky_relu(x)", + + "flash_attn(x)", + "flash_ff(x)", + "flash_attn_back(x)", + "win_part(x)", + "win_unpart(x)", + "get_rel_pos(x)", + "add_rel_pos(x)", + + "unary(x)", + + "f(x)", + "f(x,y)", + + "custom_f32(x)", + "custom_f32(x,y)", + "custom_f32(x,y,z)", + + "custom(x)", + "custom(x,y)", + "custom(x,y,z)", + + "cross_entropy_loss(x,y)", + "cross_entropy_loss_back(x,y)", +}; + 
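// Editor's note: GGML_OP_NAME and GGML_OP_SYMBOL are positional lookup tables
// indexed by enum ggml_op (see ggml_op_name()/ggml_op_symbol() below); adding a
// new op means appending an entry to both arrays and bumping the count checked
// by the GGML_OP_COUNT static_asserts that follow.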
+static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); + +static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); + + +static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { + "ABS", + "SGN", + "NEG", + "STEP", + "TANH", + "ELU", + "RELU", + "GELU", + "GELU_QUICK", + "SILU", +}; + +static_assert(GGML_UNARY_OP_COUNT == 10, "GGML_UNARY_OP_COUNT != 10"); + + +static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); +static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); + +// WARN: +// Mis-configuration can lead to problem that's hard to reason about: +// * At best it crash or talks nosense. +// * At worst it talks slightly difference but hard to perceive. +// +// An op has to enable INIT or FINALIZE when any of it's branch needs that pass. +// Take care about compile options (e.g., GGML_USE_xxx). +static bool GGML_OP_HAS_INIT [GGML_OP_COUNT] = { 0 }; +static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 }; + +static void ggml_setup_op_has_task_pass(void) { + { // INIT + bool * p = GGML_OP_HAS_INIT; + + p[GGML_OP_ACC ] = true; + p[GGML_OP_MUL_MAT ] = true; + p[GGML_OP_MUL_MAT_ID ] = true; + p[GGML_OP_OUT_PROD ] = true; + p[GGML_OP_SET ] = true; + p[GGML_OP_GET_ROWS_BACK ] = true; + p[GGML_OP_DIAG_MASK_INF ] = true; + p[GGML_OP_DIAG_MASK_ZERO ] = true; + p[GGML_OP_CONV_TRANSPOSE_1D ] = true; + p[GGML_OP_CONV_TRANSPOSE_2D ] = true; + p[GGML_OP_FLASH_ATTN_BACK ] = true; + p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; + p[GGML_OP_ADD_REL_POS ] = true; + } + + { // FINALIZE + bool * p = GGML_OP_HAS_FINALIZE; + + p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; + } +} + +// +// ggml context +// + +struct ggml_context { + size_t mem_size; + void * mem_buffer; + bool mem_buffer_owned; + bool no_alloc; + bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers + + int n_objects; + + struct ggml_object * objects_begin; + struct ggml_object * objects_end; + + struct ggml_scratch scratch; + struct ggml_scratch scratch_save; +}; + +struct ggml_context_container { + bool used; + + struct ggml_context context; +}; + +// +// NUMA support +// + +#define GGML_NUMA_MAX_NODES 8 +#define GGML_NUMA_MAX_CPUS 512 + +struct ggml_numa_node { + uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node + uint32_t n_cpus; +}; + +struct ggml_numa_nodes { + struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES]; + uint32_t n_nodes; + uint32_t total_cpus; // hardware threads on system +}; + +// +// ggml state +// + +struct ggml_state { + struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; + struct ggml_numa_nodes numa; +}; + +// global state +static struct ggml_state g_state; +static atomic_int g_state_barrier = 0; + +// barrier via spin lock +inline static void ggml_critical_section_start(void) { + int processing = atomic_fetch_add(&g_state_barrier, 1); + + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); // TODO: reconsider this + processing = atomic_fetch_add(&g_state_barrier, 1); + } +} + +// TODO: make this somehow automatically executed +// some sort of "sentry" mechanism +inline static void ggml_critical_section_end(void) { + atomic_fetch_sub(&g_state_barrier, 1); +} + +void ggml_numa_init(void) { + if (g_state.numa.n_nodes > 0) { + fprintf(stderr, "ggml_numa_init: NUMA already initialized\n"); + + return; + } + +#ifdef __linux__ + struct stat st; + char path[256]; + 
int rv; + + // enumerate nodes + while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.n_nodes; + } + + // enumerate CPUs + while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.total_cpus; + } + + GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus); + + if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1) { + g_state.numa.n_nodes = 0; + return; + } + + for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) { + struct ggml_numa_node * node = &g_state.numa.nodes[n]; + GGML_PRINT_DEBUG("CPUs on node %u:", n); + node->n_cpus = 0; + for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) == 0) { + node->cpus[node->n_cpus++] = c; + GGML_PRINT_DEBUG(" %u", c); + } + } + GGML_PRINT_DEBUG("\n"); + } + + if (ggml_is_numa()) { + FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r"); + if (fptr != NULL) { + char buf[42]; + if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { + GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); + } + fclose(fptr); + } + } +#else + // TODO +#endif +} + +bool ggml_is_numa(void) { + return g_state.numa.n_nodes > 1; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_print_object(const struct ggml_object * obj) { + GGML_PRINT(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", + obj->type, obj->offs, obj->size, (const void *) obj->next); +} + +void ggml_print_objects(const struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); + + while (obj != NULL) { + ggml_print_object(obj); + obj = obj->next; + } + + GGML_PRINT("%s: --- end ---\n", __func__); +} + +int64_t ggml_nelements(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +int64_t ggml_nrows(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +size_t ggml_nbytes(const struct ggml_tensor * tensor) { + size_t nbytes; + size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } + } + else { + nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } + } + + return nbytes; +} + +size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { + return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); +} + +int ggml_blck_size(enum ggml_type type) { + return type_traits[type].blck_size; +} + +size_t ggml_type_size(enum ggml_type type) { + return 
type_traits[type].type_size; +} + +size_t ggml_row_size(enum ggml_type type, int64_t ne) { + assert(ne % ggml_blck_size(type) == 0); + return ggml_type_size(type)*ne/ggml_blck_size(type); +} + +double ggml_type_sizef(enum ggml_type type) { + return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; +} + +const char * ggml_type_name(enum ggml_type type) { + return type_traits[type].type_name; +} + +bool ggml_is_quantized(enum ggml_type type) { + return type_traits[type].is_quantized; +} + +const char * ggml_op_name(enum ggml_op op) { + return GGML_OP_NAME[op]; +} + +const char * ggml_op_symbol(enum ggml_op op) { + return GGML_OP_SYMBOL[op]; +} + +const char * ggml_unary_op_name(enum ggml_unary_op op) { + return GGML_UNARY_OP_NAME[op]; +} + +const char * ggml_op_desc(const struct ggml_tensor * t) { + if (t->op == GGML_OP_UNARY) { + enum ggml_unary_op uop = ggml_get_unary_op(t); + return ggml_unary_op_name(uop); + } + else { + return ggml_op_name(t->op); + } +} + +size_t ggml_element_size(const struct ggml_tensor * tensor) { + return ggml_type_size(tensor->type); +} + +bool ggml_is_scalar(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +bool ggml_is_vector(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +bool ggml_is_matrix(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +bool ggml_is_3d(const struct ggml_tensor * tensor) { + return tensor->ne[3] == 1; +} + +int ggml_n_dims(const struct ggml_tensor * tensor) { + for (int i = GGML_MAX_DIMS - 1; i >= 1; --i) { + if (tensor->ne[i] > 1) { + return i + 1; + } + } + return 1; +} + +static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); +} + +static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[1] == t1->ne[1]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); +} + +enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { + enum ggml_type wtype = GGML_TYPE_COUNT; + + switch (ftype) { + case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; + case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; + case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; + case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; + case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; + case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; + case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; + case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; + case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; + case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; + case 
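/*
   Sketch: ggml_row_size() above respects the block layout of quantized types.
   Assuming the usual type traits (F32: 1-element blocks of 4 bytes, Q4_0:
   32-element blocks of 18 bytes):

       ggml_row_size(GGML_TYPE_F32,  4096);   // 4096 * 4       = 16384 bytes
       ggml_row_size(GGML_TYPE_Q4_0, 4096);   // 18 * 4096 / 32 =  2304 bytes

   The assert ne % ggml_blck_size(type) == 0 means a Q4_0 row length must be a
   multiple of 32.
*/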
GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; + case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; + } + + GGML_ASSERT(wtype != GGML_TYPE_COUNT); + + return wtype; +} + +size_t ggml_tensor_overhead(void) { + return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; +} + +bool ggml_is_transposed(const struct ggml_tensor * tensor) { + return tensor->nb[0] > tensor->nb[1]; +} + +bool ggml_is_contiguous(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +bool ggml_is_permuted(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; +} + +static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0] ) && + (t0->ne[1] == t1->ne[1] ) && + (t0->ne[2] == t1->ne[2] ) && + (t0->ne[3] == t1->ne[3] ); +} + +// check if t1 can be represented as a repeatition of t0 +static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t1->ne[0]%t0->ne[0] == 0) && + (t1->ne[1]%t0->ne[1] == 0) && + (t1->ne[2]%t0->ne[2] == 0) && + (t1->ne[3]%t0->ne[3] == 0); +} + +static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1); +} + +static inline int ggml_up32(int n) { + return (n + 31) & ~31; +} + +//static inline int ggml_up64(int n) { +// return (n + 63) & ~63; +//} + +static inline int ggml_up(int n, int m) { + // assert m is a power of 2 + GGML_ASSERT((m & (m - 1)) == 0); + return (n + m - 1) & ~(m - 1); +} + +// assert that pointer is aligned to GGML_MEM_ALIGN +#define ggml_assert_aligned(ptr) \ + GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_context * ggml_init(struct ggml_init_params params) { + // make this function thread safe + ggml_critical_section_start(); + + static bool is_first_call = true; + + if (is_first_call) { + // initialize time system (required on Windows) + ggml_time_init(); + + // initialize GELU, Quick GELU, SILU and EXP F32 tables + { + const uint64_t t_start = 
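/*
   Sketch of the broadcasting rule checked by ggml_can_repeat() above: t1 is a
   valid repeat target of t0 when every dimension of t1 is a whole multiple of
   the corresponding dimension of t0, e.g.

       t0->ne = {4096,   1, 1, 1}    // a single bias row
       t1->ne = {4096, 512, 1, 1}    // 512 rows of activations
       // 4096%4096 == 0, 512%1 == 0, 1%1 == 0, 1%1 == 0  ->  true

   ggml_add()/ggml_mul() further below rely on exactly this check, asserting
   ggml_can_repeat(b, a), i.e. b must broadcast onto a.
*/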
ggml_time_us(); UNUSED(t_start); + + ggml_fp16_t ii; + for (int i = 0; i < (1 << 16); ++i) { + uint16_t ui = i; + memcpy(&ii, &ui, sizeof(ii)); + const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); + ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); + ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); + ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + + // initialize g_state + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + g_state = (struct ggml_state) { + /*.contexts =*/ { { 0 } }, + /*.numa =*/ { + .n_nodes = 0, + .total_cpus = 0, + }, + }; + + for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { + g_state.contexts[i].used = false; + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + +#if defined(GGML_USE_CUBLAS) + ggml_init_cublas(); +#elif defined(GGML_USE_CLBLAST) + ggml_cl_init(); +#endif + + ggml_setup_op_has_task_pass(); + + is_first_call = false; + } + + // find non-used context in g_state + struct ggml_context * ctx = NULL; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (!g_state.contexts[i].used) { + g_state.contexts[i].used = true; + ctx = &g_state.contexts[i].context; + + GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); + break; + } + } + + if (ctx == NULL) { + GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); + + ggml_critical_section_end(); + + return NULL; + } + + // allow to call ggml_init with 0 size + if (params.mem_size == 0) { + params.mem_size = GGML_MEM_ALIGN; + } + + const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN); + + *ctx = (struct ggml_context) { + /*.mem_size =*/ mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.no_alloc =*/ params.no_alloc, + /*.no_alloc_save =*/ params.no_alloc, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + /*.scratch =*/ { 0, 0, NULL, }, + /*.scratch_save =*/ { 0, 0, NULL, }, + }; + + GGML_ASSERT(ctx->mem_buffer != NULL); + + ggml_assert_aligned(ctx->mem_buffer); + + GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + + ggml_critical_section_end(); + + return ctx; +} + +void ggml_free(struct ggml_context * ctx) { + // make this function thread safe + ggml_critical_section_start(); + + bool found = false; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (&g_state.contexts[i].context == ctx) { + g_state.contexts[i].used = false; + + GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n", + __func__, i, ggml_used_mem(ctx)); + + if (ctx->mem_buffer_owned) { + GGML_ALIGNED_FREE(ctx->mem_buffer); + } + + found = true; + break; + } + } + + if (!found) { + GGML_PRINT_DEBUG("%s: context not found\n", __func__); + } + + ggml_critical_section_end(); +} + +size_t ggml_used_mem(const struct ggml_context * ctx) { + return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; +} + +size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) { + const size_t result = ctx->scratch.data ? 
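/*
   Minimal lifecycle sketch for ggml_init()/ggml_free() (illustrative; the
   constructors used here are defined further below and declared in ggml.h):

       struct ggml_init_params params = {
           .mem_size   = 16*1024*1024,  // size of the context arena in bytes
           .mem_buffer = NULL,          // NULL: ggml allocates and owns the buffer
           .no_alloc   = false,         // tensors get their data from the arena
       };

       struct ggml_context * ctx = ggml_init(params);
       if (ctx == NULL) {
           // all GGML_MAX_CONTEXTS slots are already in use
       }

       struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
       // ... build and evaluate a graph ...
       ggml_free(ctx);  // releases the arena only if ggml owns it (mem_buffer_owned)
*/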
ctx->scratch.offs : 0; + + ctx->scratch = scratch; + + return result; +} + +bool ggml_get_no_alloc(struct ggml_context * ctx) { + return ctx->no_alloc; +} + +void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) { + ctx->no_alloc = no_alloc; +} + +void * ggml_get_mem_buffer(const struct ggml_context * ctx) { + return ctx->mem_buffer; +} + +size_t ggml_get_mem_size(const struct ggml_context * ctx) { + return ctx->mem_size; +} + +size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) { + size_t max_size = 0; + + struct ggml_object * obj = ctx->objects_begin; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs); + + const size_t size = ggml_nbytes(tensor); + + if (max_size < size) { + max_size = size; + } + } + + obj = obj->next; + } + + return max_size; +} + +// IMPORTANT: +// when creating "opt" tensors, always save and load the scratch buffer +// this is an error prone process, but it is necessary to support inplace +// operators when using scratch buffers +// TODO: implement a better way +static void ggml_scratch_save(struct ggml_context * ctx) { + // this is needed to allow opt tensors to store their data + // TODO: again, need to find a better way + ctx->no_alloc_save = ctx->no_alloc; + ctx->no_alloc = false; + + ctx->scratch_save = ctx->scratch; + ctx->scratch.data = NULL; +} + +static void ggml_scratch_load(struct ggml_context * ctx) { + ctx->no_alloc = ctx->no_alloc_save; + + ctx->scratch = ctx->scratch_save; +} + +//////////////////////////////////////////////////////////////////////////////// + +static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) { + // always insert objects at the end of the context's memory pool + struct ggml_object * obj_cur = ctx->objects_end; + + const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; + const size_t cur_size = obj_cur == NULL ? 
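/*
   Sketch of the scratch-buffer protocol above: while a scratch buffer is set,
   ggml_new_tensor_impl() below places tensor data in that buffer instead of the
   context arena, which keeps short-lived intermediates out of the arena
   (buf is a placeholder; the field names follow ggml.h's struct ggml_scratch):

       static char buf[32*1024*1024];

       struct ggml_scratch scr = { .offs = 0, .size = sizeof(buf), .data = buf };
       ggml_set_scratch(ctx, scr);   // returns the previous offset, 0 if none was set

       struct ggml_tensor * tmp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32);

       ggml_set_scratch(ctx, (struct ggml_scratch) { 0, 0, NULL });  // back to the arena
*/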
0 : obj_cur->size; + const size_t cur_end = cur_offs + cur_size; + + // align to GGML_MEM_ALIGN + size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN); + + char * const mem_buffer = ctx->mem_buffer; + struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); + + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed, ctx->mem_size); + assert(false); + return NULL; + } + + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + .type = type, + }; + + ggml_assert_aligned(mem_buffer + obj_new->offs); + + if (obj_cur != NULL) { + obj_cur->next = obj_new; + } else { + // this is the first object in this context + ctx->objects_begin = obj_new; + } + + ctx->objects_end = obj_new; + + //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); + + return obj_new; +} + +static struct ggml_tensor * ggml_new_tensor_impl( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t * ne, + struct ggml_tensor * view_src, + size_t view_offs) { + + assert(n_dims >= 1 && n_dims <= GGML_MAX_DIMS); + + // find the base tensor and absolute offset + if (view_src != NULL && view_src->view_src != NULL) { + view_offs += view_src->view_offs; + view_src = view_src->view_src; + } + + size_t data_size = ggml_row_size(type, ne[0]); + for (int i = 1; i < n_dims; i++) { + data_size *= ne[i]; + } + + GGML_ASSERT(view_src == NULL || data_size + view_offs <= ggml_nbytes(view_src)); + + void * data = view_src != NULL ? view_src->data : NULL; + if (data != NULL) { + data = (char *) data + view_offs; + } + + size_t obj_alloc_size = 0; + + if (view_src == NULL && !ctx->no_alloc) { + if (ctx->scratch.data != NULL) { + // allocate tensor data in the scratch buffer + if (ctx->scratch.offs + data_size > ctx->scratch.size) { + GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", + __func__, ctx->scratch.offs + data_size, ctx->scratch.size); + assert(false); + return NULL; + } + + data = (char * const) ctx->scratch.data + ctx->scratch.offs; + + ctx->scratch.offs += data_size; + } else { + // allocate tensor data in the context's memory pool + obj_alloc_size = data_size; + } + } + + struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); + + // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here + + struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); + + *result = (struct ggml_tensor) { + /*.type =*/ type, + /*.backend =*/ GGML_BACKEND_CPU, + /*.buffer =*/ NULL, + /*.ne =*/ { 1, 1, 1, 1 }, + /*.nb =*/ { 0, 0, 0, 0 }, + /*.op =*/ GGML_OP_NONE, + /*.op_params =*/ { 0 }, + /*.is_param =*/ false, + /*.grad =*/ NULL, + /*.src =*/ { NULL }, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + /*.view_src =*/ view_src, + /*.view_offs =*/ view_offs, + /*.data =*/ obj_alloc_size > 0 ? 
(void *)(result + 1) : data, + /*.name =*/ { 0 }, + /*.extra =*/ NULL, + /*.padding =*/ { 0 }, + }; + + // TODO: this should not be needed as long as we don't rely on aligned SIMD loads + //ggml_assert_aligned(result->data); + + for (int i = 0; i < n_dims; i++) { + result->ne[i] = ne[i]; + } + + result->nb[0] = ggml_type_size(type); + result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type)); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; + } + + ctx->n_objects++; + + return result; +} + +struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t * ne) { + return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0); +} + +struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0) { + return ggml_new_tensor(ctx, type, 1, &ne0); +} + +struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1) { + const int64_t ne[2] = { ne0, ne1 }; + return ggml_new_tensor(ctx, type, 2, ne); +} + +struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + const int64_t ne[3] = { ne0, ne1, ne2 }; + return ggml_new_tensor(ctx, type, 3, ne); +} + +struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + return ggml_new_tensor(ctx, type, 4, ne); +} + +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { + ggml_scratch_save(ctx); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + + ggml_scratch_load(ctx); + + ggml_set_i32(result, value); + + return result; +} + +struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { + ggml_scratch_save(ctx); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + + ggml_scratch_load(ctx); + + ggml_set_f32(result, value); + + return result; +} + +struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) { + return ggml_new_tensor(ctx, src->type, GGML_MAX_DIMS, src->ne); +} + +static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) { + GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings + assert(params_size <= GGML_MAX_OP_PARAMS); + memcpy(tensor->op_params, params, params_size); +} + +static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); + return ((const int32_t *)(tensor->op_params))[i]; +} + +static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); + ((int32_t *)(tensor->op_params))[i] = value; +} + +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { + memset(tensor->data, 0, ggml_nbytes(tensor)); + return tensor; +} + +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case 
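/*
   Sketch: the constructors above fill ne[] and derive the byte strides nb[] so
   that nb[0] is the element size and each higher stride is the previous stride
   times the previous dimension. For a 3-D F32 tensor:

       struct ggml_tensor * t = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 32, 8);
       // t->ne = {64, 32, 8, 1}
       // t->nb = {4, 256, 8192, 65536}   (bytes)

   ggml_new_i32()/ggml_new_f32() wrap the allocation in ggml_scratch_save()/
   ggml_scratch_load() so the 1-element tensor always lands in the context arena,
   never in an active scratch buffer.
*/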
GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + return tensor; +} + +struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + return tensor; +} + +void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) { + const int64_t ne2 = tensor->ne[2]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne0 = tensor->ne[0]; + + const int64_t i3_ = (i/(ne2*ne1*ne0)); + const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0); + const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0; + const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0); + + if (i0) { + * i0 = i0_; + } + if (i1) { + * i1 = i1_; + } + if (i2) { + * i2 = i2_; + } + if (i3) { + * i3 = i3_; + } +} + +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]); + } + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } + default: + { + GGML_ASSERT(false); + 
} + } + + return 0.0f; +} + +void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case GGML_TYPE_F16: + return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_F32: + return ((float *) data)[0]; + default: + GGML_ASSERT(false); + } + + return 0.0f; +} + +void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case GGML_TYPE_F16: + { + ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]); + } + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } + default: + { + GGML_ASSERT(false); + } + } + + return 0.0f; +} + +void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + ggml_set_f32_nd(tensor, id[0], 
id[1], id[2], id[3], value); + return; + } + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case GGML_TYPE_F16: + return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_F32: + return ((float *) data)[0]; + default: + GGML_ASSERT(false); + } + + return 0.0f; +} + +void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case GGML_TYPE_F16: + { + ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +void * ggml_get_data(const struct ggml_tensor * tensor) { + return tensor->data; +} + +float * ggml_get_data_f32(const struct ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + return (float *)(tensor->data); +} + +enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_UNARY); + return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); +} + +const char * ggml_get_name(const struct ggml_tensor * tensor) { + return tensor->name; +} + +struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) { + strncpy(tensor->name, name, sizeof(tensor->name)); + tensor->name[sizeof(tensor->name) - 1] = '\0'; + return tensor; +} + +struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) 
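/*
   Sketch of the element accessors above: the *_1d variants index the tensor as
   if it were flat (falling back to ggml_unravel_index() for non-contiguous
   views), while the *_nd variants take explicit indices and use the byte
   strides nb[]. For a freshly created, contiguous 4x3 F32 tensor:

       struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
       ggml_set_f32(t, 0.0f);                    // fill every element
       ggml_set_f32_nd(t, 2, 1, 0, 0, 5.0f);     // element (i0=2, i1=1)
       float  v = ggml_get_f32_1d(t, 1*4 + 2);   // same element, flat index -> 5.0f
       float *p = ggml_get_data_f32(t);          // raw contiguous data
*/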
{ + va_list args; + va_start(args, fmt); + vsnprintf(tensor->name, sizeof(tensor->name), fmt, args); + va_end(args); + return tensor; +} + +struct ggml_tensor * ggml_view_tensor( + struct ggml_context * ctx, + struct ggml_tensor * src) { + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, GGML_MAX_DIMS, src->ne, src, 0); + ggml_format_name(result, "%s (view)", src->name); + + for (int i = 0; i < GGML_MAX_DIMS; i++) { + result->nb[i] = src->nb[i]; + } + + return result; +} + +struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) { + struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE); + obj = obj->next; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { + struct ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs); + if (strcmp(cur->name, name) == 0) { + return cur; + } + } + + obj = obj->next; + } + + return NULL; +} + +//////////////////////////////////////////////////////////////////////////////// + +// ggml_dup + +static struct ggml_tensor * ggml_dup_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DUP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, true); +} + +// ggml_add + +static struct ggml_tensor * ggml_add_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD; + result->grad = is_node ? 
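/*
   Sketch: ggml_get_first_tensor()/ggml_get_next_tensor() above walk a context's
   object list, which is a convenient way to dump everything allocated in it
   (printf/stdio is assumed; the tensor name looked up is just an example):

       for (struct ggml_tensor * t = ggml_get_first_tensor(ctx);
            t != NULL;
            t = ggml_get_next_tensor(ctx, t)) {
           printf("%-32s %8.2f MiB\n", ggml_get_name(t), ggml_nbytes(t)/1024.0/1024.0);
       }

       struct ggml_tensor * w = ggml_get_tensor(ctx, "token_embd.weight");  // NULL if absent
*/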
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, true); +} + +// ggml_add_cast + +static struct ggml_tensor * ggml_add_cast_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); + GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16 + + bool is_node = false; + + if (a->grad || b->grad) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne); + + result->op = GGML_OP_ADD; + result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { + return ggml_add_cast_impl(ctx, a, b, type); +} + +// ggml_add1 + +static struct ggml_tensor * ggml_add1_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_is_scalar(b)); + GGML_ASSERT(ggml_is_padded_1d(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD1; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add1_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_add1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add1_impl(ctx, a, b, true); +} + +// ggml_acc + +static struct ggml_tensor * ggml_acc_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { + GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(b->type == GGML_TYPE_F32); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ACC; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_acc( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); +} + +struct ggml_tensor * ggml_acc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); +} + +// ggml_sub + +static struct ggml_tensor * ggml_sub_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SUB; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, true); +} + +// ggml_mul + +static struct ggml_tensor * ggml_mul_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); + is_node = true; + } + + if (inplace) { + GGML_ASSERT(!is_node); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_MUL; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, true); +} + +// ggml_div + +static struct ggml_tensor * ggml_div_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + GGML_ASSERT(!is_node); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DIV; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, true); +} + +// ggml_sqr + +static struct ggml_tensor * ggml_sqr_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQR; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, true); +} + +// ggml_sqrt + +static struct ggml_tensor * ggml_sqrt_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQRT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, true); +} + +// ggml_log + +static struct ggml_tensor * ggml_log_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_LOG; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_log( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_log_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_log_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_log_impl(ctx, a, true); +} + +// ggml_sum + +struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_SUM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_sum_rows + +struct ggml_tensor * ggml_sum_rows( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + int64_t ne[GGML_MAX_DIMS] = { 1 }; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + ne[i] = a->ne[i]; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); + + result->op = GGML_OP_SUM_ROWS; + result->grad = is_node ? 
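/*
   Sketch: the builders above (add/sub/mul/div/sqr/sqrt/log/...) only record
   nodes: op, src[] and an optional grad. Nothing is computed until the graph is
   evaluated through the graph API in ggml.h. A small elementwise expression:

       struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
       struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);

       struct ggml_tensor * y =
           ggml_sqrt(ctx,
               ggml_add(ctx,
                   ggml_mul(ctx, a, a),
                   ggml_mul(ctx, b, b)));   // sqrt(a*a + b*b), elementwise

   The *_inplace variants return a view of the input (ggml_view_tensor) instead
   of allocating a fresh tensor.
*/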
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_mean + +struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement + is_node = true; + } + + int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MEAN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_argmax + +struct ggml_tensor * ggml_argmax( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_matrix(a)); + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]); + + result->op = GGML_OP_ARGMAX; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_repeat + +struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_repeat(a, b)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne); + + result->op = GGML_OP_REPEAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_repeat_back + +struct ggml_tensor * ggml_repeat_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (ggml_are_same_shape(a, b) && !is_node) { + return a; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne); + + result->op = GGML_OP_REPEAT_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_concat + +struct ggml_tensor * ggml_concat( + struct ggml_context* ctx, + struct ggml_tensor* a, + struct ggml_tensor* b) { + GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]); + + result->op = GGML_OP_CONCAT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_abs + +struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_ABS); +} + +struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS); +} + +// ggml_sgn + +struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SGN); +} + +struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN); +} + +// ggml_neg + +struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_NEG); +} + +struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG); +} + +// ggml_step + +struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_STEP); +} + +struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP); +} + +// ggml_tanh + +struct ggml_tensor * ggml_tanh( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_TANH); +} + +struct ggml_tensor * ggml_tanh_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH); +} + +// ggml_elu + +struct ggml_tensor * ggml_elu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_ELU); +} + +struct ggml_tensor * ggml_elu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU); +} + +// ggml_relu + +struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_RELU); +} + +struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); +} + +// ggml_leaky_relu + +struct ggml_tensor * ggml_leaky_relu( + struct ggml_context * ctx, + struct ggml_tensor * a, float negative_slope, bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); + + result->op = GGML_OP_LEAKY_RELU; + result->grad = is_node ? 
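/*
   Sketch: the wrappers above (ggml_abs ... ggml_relu) all forward to
   ggml_unary()/ggml_unary_inplace() (defined later in this file) with the
   matching GGML_UNARY_OP_* value stored in op_params, so

       ggml_relu(ctx, x);

   builds the same node as ggml_unary(ctx, x, GGML_UNARY_OP_RELU), and
   ggml_get_unary_op() above recovers the enum from the resulting tensor.
*/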
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_gelu + +struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU); +} + +struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU); +} + +// ggml_gelu_quick + +struct ggml_tensor * ggml_gelu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK); +} + +struct ggml_tensor * ggml_gelu_quick_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK); +} + +// ggml_silu + +struct ggml_tensor * ggml_silu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SILU); +} + +struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); +} + +// ggml_silu_back + +struct ggml_tensor * ggml_silu_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + bool is_node = false; + + if (a->grad || b->grad) { + // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SILU_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_norm + +static struct ggml_tensor * ggml_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, true); +} + +// ggml_rms_norm + +static struct ggml_tensor * ggml_rms_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_RMS_NORM; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, true); +} + +// ggml_rms_norm_back + +struct ggml_tensor * ggml_rms_norm_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float eps) { + bool is_node = false; + + if (a->grad) { + // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_RMS_NORM_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_group_norm + +static struct ggml_tensor * ggml_group_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + bool inplace) { + + bool is_node = false; + if (!inplace && (a->grad)) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op_params[0] = n_groups; + + result->op = GGML_OP_GROUP_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct ggml_tensor * ggml_group_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups) { + return ggml_group_norm_impl(ctx, a, n_groups, false); +} + +struct ggml_tensor * ggml_group_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups) { + return ggml_group_norm_impl(ctx, a, n_groups, true); +} + +// ggml_mul_mat + +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_mul_mat(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MUL_MAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_mul_mat_id + +struct ggml_tensor * ggml_mul_mat_id( + struct ggml_context * ctx, + struct ggml_tensor * const as[], + int n_as, + struct ggml_tensor * ids, + int id, + struct ggml_tensor * b) { + + GGML_ASSERT(ids->type == GGML_TYPE_I32); + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); + GGML_ASSERT(ids->ne[1] == b->ne[1]); + GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]); + GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2); + GGML_ASSERT(id >= 0 && id < ids->ne[0]); + + bool is_node = false; + + if (as[0]->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_i32(result, 0, id); + ggml_set_op_params_i32(result, 1, n_as); + + result->op = GGML_OP_MUL_MAT_ID; + result->grad = is_node ? 
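/*
   Sketch of a typical projection using the builders above (cur, w, n_embd, n_ff
   and n_tokens are placeholder names): ggml_mul_mat(a, b) requires
   a->ne[0] == b->ne[0] and yields { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }.

       // cur: {n_embd, n_tokens} activations, w: {n_embd, n_ff} weights
       struct ggml_tensor * normed = ggml_rms_norm(ctx, cur, 1e-5f);   // eps is model-specific
       struct ggml_tensor * out    = ggml_mul_mat(ctx, w, normed);     // {n_ff, n_tokens}
*/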
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = ids; + result->src[1] = b; + + for (int i = 0; i < n_as; i++) { + struct ggml_tensor * a = as[i]; + GGML_ASSERT(ggml_are_same_shape(as[0], a)); + GGML_ASSERT(ggml_can_mul_mat(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + result->src[i + 2] = a; + } + + return result; +} + +// ggml_out_prod + +struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_out_prod(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] + const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_OUT_PROD; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_scale + +static struct ggml_tensor * ggml_scale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_is_scalar(b)); + GGML_ASSERT(ggml_is_padded_1d(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SCALE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, true); +} + +// ggml_set + +static struct ggml_tensor * ggml_set_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { + GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // make a view of the destination + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_SET; + result->grad = is_node ? 
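/*
   Sketch: in this revision ggml_scale() above takes the factor as a 1-element
   F32 tensor (note the GGML_ASSERT(ggml_is_scalar(b))), so it is usually paired
   with ggml_new_f32() (kq is a placeholder name):

       struct ggml_tensor * scaled = ggml_scale(ctx, kq, ggml_new_f32(ctx, 0.125f));
*/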
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false); +} + +struct ggml_tensor * ggml_set_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true); +} + +struct ggml_tensor * ggml_set_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset) { + return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); +} + +struct ggml_tensor * ggml_set_1d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset) { + return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); +} + +struct ggml_tensor * ggml_set_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); +} + +struct ggml_tensor * ggml_set_2d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true); +} + +// ggml_cpy + +static struct ggml_tensor * ggml_cpy_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + // make a view of the destination + struct ggml_tensor * result = ggml_view_tensor(ctx, b); + if (strlen(b->name) > 0) { + ggml_format_name(result, "%s (copy of %s)", b->name, a->name); + } else { + ggml_format_name(result, "%s (copy)", a->name); + } + + result->op = GGML_OP_CPY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_cpy_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, true); +} + +// ggml_cont + +static struct ggml_tensor * ggml_cont_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_format_name(result, "%s (cont)", a->name); + + result->op = GGML_OP_CONT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cont_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_cont_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cont_impl(ctx, a, true); +} + +// make contiguous, with new shape +GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0) { + return ggml_cont_4d(ctx, a, ne0, 1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1); +} + +struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3)); + + bool is_node = false; + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + ggml_format_name(result, "%s (cont)", a->name); + + result->op = GGML_OP_CONT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_reshape + +struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_is_contiguous(a)); + // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (b->grad) { + // gradient propagation is not supported + //GGML_ASSERT(false); + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[1] = { ne0 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[2] = { ne0, ne1 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[3] = { ne0, ne1, ne2 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +static struct ggml_tensor * ggml_view_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_dims, + const int64_t * ne, + size_t offset) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); + ggml_format_name(result, "%s (view)", a->name); + + ggml_set_op_params(result, &offset, sizeof(offset)); + + result->op = GGML_OP_VIEW; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_view_1d + +struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + size_t offset) { + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset); + + return result; +} + +// ggml_view_2d + +struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, + size_t offset) { + + const int64_t ne[2] = { ne0, ne1 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = result->nb[1]*ne1; + result->nb[3] = result->nb[2]; + + return result; +} + +// ggml_view_3d + +struct ggml_tensor * ggml_view_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, + size_t nb2, + size_t offset) { + + const int64_t ne[3] = { ne0, ne1, ne2 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = result->nb[2]*ne2; + + return result; +} + +// ggml_view_4d + +struct ggml_tensor * ggml_view_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = nb3; + + return result; +} + +// ggml_permute + +struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3) { + GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS); + GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS); + GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS); + GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS); + + GGML_ASSERT(axis0 != axis1); + GGML_ASSERT(axis0 != axis2); + GGML_ASSERT(axis0 != axis3); + GGML_ASSERT(axis1 != axis2); + GGML_ASSERT(axis1 != axis3); + GGML_ASSERT(axis2 != axis3); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + ggml_format_name(result, "%s (permuted)", a->name); + + int ne[GGML_MAX_DIMS]; + int nb[GGML_MAX_DIMS]; + + ne[axis0] = a->ne[0]; + ne[axis1] = a->ne[1]; + ne[axis2] = a->ne[2]; + ne[axis3] = a->ne[3]; + + nb[axis0] = a->nb[0]; + nb[axis1] = a->nb[1]; + nb[axis2] = a->nb[2]; + nb[axis3] = a->nb[3]; + + result->ne[0] = ne[0]; + result->ne[1] = ne[1]; + result->ne[2] = ne[2]; + result->ne[3] = ne[3]; + + result->nb[0] = nb[0]; + result->nb[1] = nb[1]; + result->nb[2] = nb[2]; + result->nb[3] = nb[3]; + + result->op = GGML_OP_PERMUTE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + int32_t params[] = { axis0, axis1, axis2, axis3 }; + ggml_set_op_params(result, params, sizeof(params)); + + return result; +} + +// ggml_transpose + +struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + ggml_format_name(result, "%s (transposed)", a->name); + + result->ne[0] = a->ne[1]; + result->ne[1] = a->ne[0]; + + result->nb[0] = a->nb[1]; + result->nb[1] = a->nb[0]; + + result->op = GGML_OP_TRANSPOSE; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_get_rows + +struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(b->ne[3] == 1); + GGML_ASSERT(b->type == GGML_TYPE_I32); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // TODO: implement non F32 return + //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); + + result->op = GGML_OP_GET_ROWS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_get_rows_back + +struct ggml_tensor * ggml_get_rows_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // TODO: implement non F32 return + //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]); + + result->op = GGML_OP_GET_ROWS_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_diag + +struct ggml_tensor * ggml_diag( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(a->ne[1] == 1); + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne); + + result->op = GGML_OP_DIAG; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_diag_mask_inf + +static struct ggml_tensor * ggml_diag_mask_inf_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + bool inplace) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { n_past }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_DIAG_MASK_INF; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_inf_impl(ctx, a, n_past, false); +} + +struct ggml_tensor * ggml_diag_mask_inf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_inf_impl(ctx, a, n_past, true); +} + +// ggml_diag_mask_zero + +static struct ggml_tensor * ggml_diag_mask_zero_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + bool inplace) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { n_past }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_DIAG_MASK_ZERO; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_diag_mask_zero( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_zero_impl(ctx, a, n_past, false); +} + +struct ggml_tensor * ggml_diag_mask_zero_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_zero_impl(ctx, a, n_past, true); +} + +// ggml_soft_max + +static struct ggml_tensor * ggml_soft_max_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + bool inplace) { + GGML_ASSERT(ggml_is_contiguous(a)); + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + } + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_SOFT_MAX; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = mask; + + return result; +} + +struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false); +} + +struct ggml_tensor * ggml_soft_max_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true); +} + +struct ggml_tensor * ggml_soft_max_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale) { + return ggml_soft_max_impl(ctx, a, mask, scale, false); +} + +// ggml_soft_max_back + +static struct ggml_tensor * ggml_soft_max_back_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; // TODO : implement backward pass + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SOFT_MAX_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_soft_max_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_soft_max_back_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_soft_max_back_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_soft_max_back_impl(ctx, a, b, true); +} + +// ggml_rope + +static struct ggml_tensor * ggml_rope_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + float xpos_base, + bool xpos_down, + bool inplace) { + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(params + 11, &xpos_base, sizeof(float)); + memcpy(params + 12, &xpos_down, sizeof(bool)); + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ROPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx) { + return ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ); +} + +struct ggml_tensor * ggml_rope_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx) { + return ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ); +} + +struct ggml_tensor * ggml_rope_custom( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ); +} + +struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ); +} + +struct ggml_tensor * ggml_rope_xpos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + float base, + bool down) { + return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); +} + +// ggml_rope_back + +struct ggml_tensor * ggml_rope_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + float xpos_base, + bool xpos_down) { + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + + GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); + + bool is_node = false; + + if (a->grad) { + is_node = false; // TODO: implement backward + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + 
memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(params + 11, &xpos_base, sizeof(float)); + memcpy(params + 12, &xpos_down, sizeof(bool)); + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ROPE_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_alibi + +struct ggml_tensor * ggml_alibi( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_head, + float bias_max) { + GGML_ASSERT(n_past >= 0); + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + int32_t op_params[3] = { n_past, n_head }; + memcpy(op_params + 2, &bias_max, sizeof(float)); + ggml_set_op_params(result, op_params, sizeof(op_params)); + + result->op = GGML_OP_ALIBI; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_clamp + +struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + float params[] = { min, max }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CLAMP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_conv_1d + +static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; +} + +GGML_API struct ggml_tensor * ggml_conv_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K] + + struct ggml_tensor * result = + ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K] + ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K] + + result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL] + + return result; +} + +// ggml_conv_1d_ph + +struct ggml_tensor* ggml_conv_1d_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s, + int d) { + return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); +} + +// ggml_conv_transpose_1d + +static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; +} + +GGML_API struct ggml_tensor * ggml_conv_transpose_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + GGML_ASSERT(ggml_is_matrix(b)); + GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(a->ne[3] == 1); + + GGML_ASSERT(p0 == 0); + GGML_ASSERT(d0 == 1); + + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + ggml_calc_conv_transpose_1d_output_size(b->ne[0], 
a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), + a->ne[1], b->ne[2], 1, + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { s0, p0, d0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_TRANSPOSE_1D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_conv_2d + +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OH, OW, IC*KH*KW] +struct ggml_tensor * ggml_im2col( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D) { + + if(is_2D) { + GGML_ASSERT(a->ne[2] == b->ne[2]); + } else { + GGML_ASSERT(a->ne[1] == b->ne[1]); + } + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; + const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + + const int64_t ne[4] = { + is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], + OW, + is_2D ? OH : b->ne[2], + is_2D ? b->ne[3] : 1, + }; + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OC, OH, OW] +struct ggml_tensor * ggml_conv_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW] + + struct ggml_tensor * result = + ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW] + ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW] + + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW] + + return result; +} + +// ggml_conv_2d_sk_p0 +struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1); +} + +// ggml_conv_2d_s1_ph + +struct ggml_tensor * ggml_conv_2d_s1_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1); +} + +// ggml_conv_transpose_2d_p0 + +static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { + return (ins - 1) * s - 2 * p + ks; +} + +struct ggml_tensor * ggml_conv_transpose_2d_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride) { + GGML_ASSERT(a->ne[3] == b->ne[2]); + + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), + 
ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), + a->ne[2], b->ne[3], + }; + + struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_i32(result, 0, stride); + + result->op = GGML_OP_CONV_TRANSPOSE_2D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_pool_* + +static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) { + return (ins + 2 * p - ks) / s + 1; +} + +// ggml_pool_1d + +struct ggml_tensor * ggml_pool_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int s0, + int p0) { + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[2] = { + ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), + a->ne[1], + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + int32_t params[] = { op, k0, s0, p0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_POOL_1D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_pool_2d + +struct ggml_tensor * ggml_pool_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1) { + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[3] = { + ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), + ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), + a->ne[2], + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_POOL_2D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_upscale + +static struct ggml_tensor * ggml_upscale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, + a->ne[0] * scale_factor, + a->ne[1] * scale_factor, + a->ne[2], a->ne[3]); + + result->op = GGML_OP_UPSCALE; + result->op_params[0] = scale_factor; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = NULL; + + return result; +} + +struct ggml_tensor * ggml_pad( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, int p1, int p2, int p3) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, + a->ne[0] + p0, + a->ne[1] + p1, + a->ne[2] + p2, + a->ne[3] + p3); + + result->op = GGML_OP_PAD; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor) { + return ggml_upscale_impl(ctx, a, scale_factor); +} + +// ggml_argsort + +struct ggml_tensor * ggml_argsort( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_sort_order order) { + bool is_node = false; + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); + + ggml_set_op_params_i32(result, 0, (int32_t) order); + + result->op = GGML_OP_ARGSORT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_top_k + +struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k) { + GGML_ASSERT(a->ne[0] >= k); + + struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_DESC); + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + +// ggml_flash_attn + +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + + int32_t t = masked ? 1 : 0; + ggml_set_op_params(result, &t, sizeof(t)); + + result->op = GGML_OP_FLASH_ATTN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + + return result; +} + +// ggml_flash_ff + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1) { + GGML_ASSERT(ggml_can_mul_mat(b0, a)); + // TODO: more checks + + bool is_node = false; + + if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne); + + result->op = GGML_OP_FLASH_FF; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b0; + result->src[2] = b1; + result->src[3] = c0; + result->src[4] = c1; + + return result; +} + +// ggml_flash_attn_back + +struct ggml_tensor * ggml_flash_attn_back( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * d, + bool masked) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + // d shape [D,N,ne2,ne3] + // q shape [D,N,ne2,ne3] + // k shape [D,M,kvne2,ne3] + // v shape [M,D,kvne2,ne3] + + const int64_t D = q->ne[0]; + const int64_t N = q->ne[1]; + const int64_t M = k->ne[1]; + const int64_t ne2 = q->ne[2]; + const int64_t ne3 = q->ne[3]; + const int64_t kvne2 = k->ne[2]; + + GGML_ASSERT(k->ne[0] == D); + GGML_ASSERT(v->ne[0] == M); + GGML_ASSERT(v->ne[1] == D); + GGML_ASSERT(d->ne[0] == D); + GGML_ASSERT(d->ne[1] == N); + GGML_ASSERT(k->ne[2] == kvne2); + GGML_ASSERT(k->ne[3] == ne3); + GGML_ASSERT(v->ne[2] == kvne2); + GGML_ASSERT(v->ne[3] == ne3); + GGML_ASSERT(d->ne[2] == ne2); + GGML_ASSERT(d->ne[3] == ne3); + + GGML_ASSERT(ne2 % kvne2 == 0); + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + // when using this operation (in backwards pass) these grads are set. + // we don't want to create (big) grad of our result, so is_node is false. + is_node = false; + } + + // store gradients of q, k and v as continuous tensors concatenated in result. + // note: v and gradv are actually transposed, i.e. v->ne[0] != D. + const int64_t elem_q = ggml_nelements(q); + const int64_t elem_k = ggml_nelements(k); + const int64_t elem_v = ggml_nelements(v); + + enum ggml_type result_type = GGML_TYPE_F32; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN); + + const size_t nelements = (end + tsize - 1)/tsize; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements); + + int32_t masked_i = masked ? 1 : 0; + ggml_set_op_params(result, &masked_i, sizeof(masked_i)); + + result->op = GGML_OP_FLASH_ATTN_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = d; + + return result; +} + +// ggml_win_part + +struct ggml_tensor * ggml_win_part( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w) { + GGML_ASSERT(a->ne[3] == 1); + GGML_ASSERT(a->type == GGML_TYPE_F32); + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // padding + const int px = (w - a->ne[1]%w)%w; + const int py = (w - a->ne[2]%w)%w; + + const int npx = (px + a->ne[1])/w; + const int npy = (py + a->ne[2])/w; + const int np = npx*npy; + + const int64_t ne[4] = { a->ne[0], w, w, np, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { npx, npy, w }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_WIN_PART; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_win_unpart + +struct ggml_tensor * ggml_win_unpart( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int w) { + GGML_ASSERT(a->type == GGML_TYPE_F32); + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + + int32_t params[] = { w }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_WIN_UNPART; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_get_rel_pos + +struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh) { + GGML_ASSERT(qh == kh); + GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne); + + result->op = GGML_OP_GET_REL_POS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = NULL; + + return result; +} + +// ggml_add_rel_pos + +static struct ggml_tensor * ggml_add_rel_pos_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(pw, ph)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(pw)); + GGML_ASSERT(ggml_is_contiguous(ph)); + GGML_ASSERT(ph->type == GGML_TYPE_F32); + GGML_ASSERT(pw->type == GGML_TYPE_F32); + GGML_ASSERT(pw->ne[3] == a->ne[2]); + GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); + GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); + + bool is_node = false; + + if (!inplace && (a->grad || pw->grad || ph->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); + + result->op = GGML_OP_ADD_REL_POS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = pw; + result->src[2] = ph; + + return result; +} + +struct ggml_tensor * ggml_add_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph) { + return ggml_add_rel_pos_impl(ctx, a, pw, ph, false); +} + +struct ggml_tensor * ggml_add_rel_pos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph) { + return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); +} + +// gmml_unary + +static struct ggml_tensor * ggml_unary_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + result->op = GGML_OP_UNARY; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op) { + return ggml_unary_impl(ctx, a, op, false); +} + +struct ggml_tensor * ggml_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op) { + return ggml_unary_impl(ctx, a, op, true); +} + +// ggml_map_unary + +static struct ggml_tensor * ggml_map_unary_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_unary_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_UNARY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_map_unary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_unary_op_f32_t fun) { + return ggml_map_unary_impl_f32(ctx, a, fun, false); +} + +struct ggml_tensor * ggml_map_unary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_unary_op_f32_t fun) { + return ggml_map_unary_impl_f32(ctx, a, fun, true); +} + +// ggml_map_binary + +static struct ggml_tensor * ggml_map_binary_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_binary_op_f32_t fun, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_BINARY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_map_binary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_binary_op_f32_t fun) { + return ggml_map_binary_impl_f32(ctx, a, b, fun, false); +} + +struct ggml_tensor * ggml_map_binary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_binary_op_f32_t fun) { + return ggml_map_binary_impl_f32(ctx, a, b, fun, true); +} + +// ggml_map_custom1_f32 + +static struct ggml_tensor * ggml_map_custom1_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_CUSTOM1_F32; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_map_custom1_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_f32_t fun) { + return ggml_map_custom1_impl_f32(ctx, a, fun, false); +} + +struct ggml_tensor * ggml_map_custom1_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_f32_t fun) { + return ggml_map_custom1_impl_f32(ctx, a, fun, true); +} + +// ggml_map_custom2_f32 + +static struct ggml_tensor * ggml_map_custom2_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_CUSTOM2_F32; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_map_custom2_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_f32_t fun) { + return ggml_map_custom2_impl_f32(ctx, a, b, fun, false); +} + +struct ggml_tensor * ggml_map_custom2_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_f32_t fun) { + return ggml_map_custom2_impl_f32(ctx, a, b, fun, true); +} + +// ggml_map_custom3_f32 + +static struct ggml_tensor * ggml_map_custom3_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad || b->grad || c->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_CUSTOM3_F32; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +struct ggml_tensor * ggml_map_custom3_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_f32_t fun) { + return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false); +} + +struct ggml_tensor * ggml_map_custom3_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_f32_t fun) { + return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true); +} + +// ggml_map_custom1 +struct ggml_map_custom1_op_params { + ggml_custom1_op_t fun; + int n_tasks; + void * userdata; +}; + +static struct ggml_tensor * ggml_map_custom1_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); + + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + struct ggml_map_custom1_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + + result->op = GGML_OP_MAP_CUSTOM1; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_map_custom1( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false); +} + +struct ggml_tensor * ggml_map_custom1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true); +} + +// ggml_map_custom2 + +struct ggml_map_custom2_op_params { + ggml_custom2_op_t fun; + int n_tasks; + void * userdata; +}; + +static struct ggml_tensor * ggml_map_custom2_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + struct ggml_map_custom2_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + + result->op = GGML_OP_MAP_CUSTOM2; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_map_custom2( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false); +} + +struct ggml_tensor * ggml_map_custom2_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true); +} + +// ggml_map_custom3 + +struct ggml_map_custom3_op_params { + ggml_custom3_op_t fun; + int n_tasks; + void * userdata; +}; + +static struct ggml_tensor * ggml_map_custom3_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad || c->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + struct ggml_map_custom3_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + + result->op = GGML_OP_MAP_CUSTOM3; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +struct ggml_tensor * ggml_map_custom3( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false); +} + +struct ggml_tensor * ggml_map_custom3_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); +} + +// ggml_cross_entropy_loss + +struct ggml_tensor * ggml_cross_entropy_loss( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_cross_entropy_loss_back + +struct ggml_tensor * ggml_cross_entropy_loss_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_is_scalar(c)); + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; + result->grad = NULL; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor) { + tensor->is_param = true; + + GGML_ASSERT(tensor->grad == NULL); + tensor->grad = ggml_dup_tensor(ctx, tensor); + ggml_format_name(tensor->grad, "%s (grad)", tensor->name); +} + +// ggml_compute_forward_dup + +static void ggml_compute_forward_dup_same_cont( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == dst->type); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const size_t nb00 = src0->nb[0]; + const size_t nb0 = dst->nb[0]; + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by elements + const int ne = ggml_nelements(dst); + const int dr = (ne + nth - 1) / nth; + const int ie0 = dr * ith; + const int ie1 = MIN(ie0 + dr, ne); + + if (ie0 < ie1) { + memcpy( + ((char *) dst->data + ie0*nb0), + ((char *) src0->data + ie0*nb00), + (ie1 - ie0) * ggml_type_size(src0->type)); + } + +} +static void ggml_compute_forward_dup_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && 
src0->type == dst->type) { + ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + + if (ggml_is_contiguous(dst)) { + if (nb00 == sizeof(ggml_fp16_t)) { + if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (type_traits[dst->type].from_float) { + ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { 
+ for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + return; + } + + // dst counters + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t)); + + if (++i10 == ne00) { + i10 = 0; + if (++i11 == ne01) { + i11 = 0; + if (++i12 == ne02) { + i12 = 0; + if (++i13 == ne03) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } +} + +static void ggml_compute_forward_dup_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { + ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < 
ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (ggml_is_contiguous(dst)) { + // TODO: simplify + if (nb00 == sizeof(float)) { + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (type_traits[dst->type].from_float) { + ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + quantize_row_q(src0_ptr, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(float)); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + for 
(int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } +} + +static void ggml_compute_forward_dup( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { + ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_dup_f16(params, src0, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_dup_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add + +static void ggml_compute_forward_add_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { +#ifdef GGML_USE_ACCELERATE + vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); +#else + ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % 
ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; + } + } + } +} + +static void ggml_compute_forward_add_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + if (dst->type == GGML_TYPE_F32) { + GGML_ASSERT( nb0 == sizeof(float)); + } + else { + GGML_ASSERT(dst->type == GGML_TYPE_F16); + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + } + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + if (dst->type == GGML_TYPE_F16) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } else { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + } + } + } + } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } +} + +static void ggml_compute_forward_add_f16_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int 
dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i])); + } + } + } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } +} + +static void ggml_compute_forward_add_q_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + const enum ggml_type dtype = dst->type; + ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ggml_is_quantized(src0->type)); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + // src1 and dst are same shape as src0 => same indices + const int i13 = i03; + const int i12 = i02; + const int i11 = i01; + + const int i3 = i03; + const int i2 = i02; + const int i1 = i01; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); + void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne00); + // add src1 + ggml_vec_acc_f32(ne00, wdata, src1_row); + // quantize row to dst + if (quantize_row_q != NULL) { + quantize_row_q(wdata, dst_row, ne00); + } else { + memcpy(dst_row, wdata, ne0*nb0); + } + } +} + +static void ggml_compute_forward_add( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + if (src1->type == GGML_TYPE_F16) { + 
ggml_compute_forward_add_f16_f16(params, src0, src1, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add_f16_f32(params, src0, src1, dst); + } + else { + GGML_ASSERT(false); + } + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + { + ggml_compute_forward_add_q_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add1 + +static void ggml_compute_forward_add1_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + +#ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_add1_f32); + + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) src1->data), 0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + ggml_vec_add1_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + *(float *) src1->data); +#endif + } +} + +static void ggml_compute_forward_add1_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1_f16_f16( + const struct ggml_compute_params * params, 
+ const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scalar to add + const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1_q_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + ggml_from_float_t const quantize_row_q = type_traits[type].from_float; + + // we don't support permuted src0 + GGML_ASSERT(nb00 == ggml_type_size(type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ggml_is_quantized(src0->type)); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03)); + void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 )); + + assert(ne0 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne0); + // add src1 + ggml_vec_acc1_f32(ne0, wdata, v); + // quantize row to dst + quantize_row_q(wdata, dst_row, ne0); + } +} + +static void ggml_compute_forward_add1( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case 
GGML_TYPE_F32: + { + ggml_compute_forward_add1_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + if (src1->type == GGML_TYPE_F16) { + ggml_compute_forward_add1_f16_f16(params, src0, src1, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add1_f16_f32(params, src0, src1, dst); + } + else { + GGML_ASSERT(false); + } + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + { + ggml_compute_forward_add1_q_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_acc + +static void ggml_compute_forward_acc_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during acc + // nb0 is implicitly element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace && (params->type == GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src1); + const int nc = src1->ne[0]; + + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + + // src0 and dst as viewed during acc + const size_t nb0 = ggml_element_size(src0); + + const size_t nb00 = nb0; + const size_t nb01 = nb1; + const size_t nb02 = nb2; + const size_t nb03 = nb3; + + GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst)); + GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 
0 : ne13-1)*nb03 < ggml_nbytes(src0)); + + GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + +#ifdef GGML_USE_ACCELERATE + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1, + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc); +#else + ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + } +} + +static void ggml_compute_forward_acc( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_acc_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sub + +static void ggml_compute_forward_sub_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + +#ifdef GGML_USE_ACCELERATE + vDSP_vsub( + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + ggml_vec_sub_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + // } + // } + } + } else { + // src1 is not contiguous + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i0 = 0; i0 < ne0; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + + dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; + } + } + } +} + 
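The element-wise kernels in this hunk are never called directly; ggml_compute_forward() dispatches to them when a computation graph is evaluated. As a quick orientation for reviewers, here is a minimal sketch of how ggml_compute_forward_sub_f32 above would be reached through the public ggml.h API (illustration only; it assumes the usual entry points ggml_init, ggml_new_tensor_1d, ggml_sub, ggml_new_graph, ggml_build_forward_expand and ggml_graph_compute_with_ctx):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    // small scratch arena; the size is arbitrary for this sketch
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // two F32 vectors; ggml_sub() only records a GGML_OP_SUB node, nothing is computed yet
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(a, 3.0f);
    ggml_set_f32(b, 1.0f);
    struct ggml_tensor * c = ggml_sub(ctx, a, b);

    // evaluating the graph is what lands in ggml_compute_forward_sub_f32
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    printf("c[0] = %f\n", ggml_get_f32_1d(c, 0)); // expected: 2.000000

    ggml_free(ctx);
    return 0;
}

The same path exercises every forward kernel added in this file; only the node constructor (ggml_add, ggml_mul, ggml_repeat, ggml_gelu, ...) changes.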
+static void ggml_compute_forward_sub( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sub_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_mul + +static void ggml_compute_forward_mul_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + const int ith = params->ith; + const int nth = params->nth; + +#ifdef GGML_USE_CLBLAST + if (src1->backend == GGML_BACKEND_GPU) { + // TODO: OpenCL kernel support full broadcast + GGML_ASSERT(ggml_can_repeat_rows(src1, src0)); + if (ith == 0) { + ggml_cl_mul(src0, src1, dst); + } + return; + } +#endif + + const int64_t nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0 ; r < nr0; ++r) { +#ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_mul_f32); + + vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); +#else + ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); + } + } + } +} + +static void ggml_compute_forward_mul( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now"); + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mul_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_div + +static void ggml_compute_forward_div_f32( + const struct ggml_compute_params * 
params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { +#ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_div_f32); + + vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); +#else + ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); + } + } + } +} + +static void ggml_compute_forward_div( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_div_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sqr + +static void ggml_compute_forward_sqr_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqr_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqr( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqr_f32(params, 
src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sqrt + +static void ggml_compute_forward_sqrt_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqrt_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqrt( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqrt_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_log + +static void ggml_compute_forward_log_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_log_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_log( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_log_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sum + +static void ggml_compute_forward_sum_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(float)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + ggml_float sum = 0; + ggml_float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32_ggf(ne00, + &row_sum, + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + sum += row_sum; + } + } + } + ((float *) dst->data)[0] = sum; +} + +static void ggml_compute_forward_sum_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + float sum = 0; + float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < 
ne01; i01++) { + ggml_vec_sum_f16_ggf(ne00, + &row_sum, + (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; + } + } + } + ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); +} + +static void ggml_compute_forward_sum( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sum_f32(params, src0, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sum_f16(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sum_rows + +static void ggml_compute_forward_sum_rows_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + float row_sum = 0; + ggml_vec_sum_f32(ne00, &row_sum, src_row); + dst_row[0] = row_sum; + } + } + } +} + +static void ggml_compute_forward_sum_rows( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sum_rows_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_mean + +static void ggml_compute_forward_mean_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + assert(ne0 == 1); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); + + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32(ne00, + (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; + } + } + } +} + +static void ggml_compute_forward_mean( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mean_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_argmax + +static void ggml_compute_forward_argmax_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + assert(dst->nb[0] == 
sizeof(float)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + const size_t nb01 = src0->nb[1]; + const size_t nb0 = dst->nb[0]; + + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src = (float *) ((char *) src0->data + i1*nb01); + int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); + int v = 0; + ggml_vec_argmax_f32(ne00, &v, src); + dst_[0] = v; + } +} + +static void ggml_compute_forward_argmax( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argmax_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_repeat + +static void ggml_compute_forward_repeat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_can_repeat(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_vec_cpy_f32(ne00, + (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), + (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_repeat_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_can_repeat(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // TODO: maybe this is not optimal? 
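+    // i0..i3 select which repetition of src0 is being written and k1..k3 walk the
+    // original src0 extent; the innermost loop copies one row of ne00 F16 values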
+ for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); + ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); + // ggml_vec_cpy_f16(ne00, y, x) + for (int i = 0; i < ne00; ++i) { + y[i] = x[i]; + } + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_repeat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_repeat_f16(params, src0, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_repeat_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_repeat_back + +static void ggml_compute_forward_repeat_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_can_repeat(dst, src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne00/ne0); + const int nr1 = (int)(ne01/ne1); + const int nr2 = (int)(ne02/ne2); + const int nr3 = (int)(ne03/ne3); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (ggml_is_contiguous(dst)) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } else { + for (int k3 = 0; k3 < ne3; k3++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int k1 = 0; k1 < ne1; k1++) { + ggml_vec_set_f32(ne0, + (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), + 0); + } + } + } + } + + // TODO: maybe this is not optimal? 
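+    // reverse of repeat: i0..i3 enumerate the repetitions contained in src0 (the incoming
+    // gradient), k1..k3 walk dst, and each innermost call accumulates one ne0-element row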
+ for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne3; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne1; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_vec_acc_f32(ne0, + (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), + (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_repeat_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_repeat_back_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_concat + +static void ggml_compute_forward_concat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + if (i2 < ne02) { // src0 + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03); + + float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + *y = *x; + } + } + } // src1 + else { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13); + + float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + *y = *x; + } + } + } + } + } +} + +static void ggml_compute_forward_concat( + const struct ggml_compute_params* params, + const struct ggml_tensor* src0, + const struct ggml_tensor* src1, + struct ggml_tensor* dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_concat_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_abs + +static void ggml_compute_forward_abs_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_abs_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_abs( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_abs_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// 
ggml_compute_forward_sgn + +static void ggml_compute_forward_sgn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sgn_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sgn( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sgn_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_neg + +static void ggml_compute_forward_neg_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_neg_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_neg( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_neg_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_step + +static void ggml_compute_forward_step_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_step_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_step( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_step_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_tanh + +static void ggml_compute_forward_tanh_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_tanh_f32(nc, + (float *) ((char *) dst->data + i*( 
dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_tanh( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_tanh_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_elu + +static void ggml_compute_forward_elu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_elu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_elu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_elu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_relu + +static void ggml_compute_forward_relu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_relu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_relu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_gelu + +static void ggml_compute_forward_gelu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu( + const struct ggml_compute_params * params, + 
const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_gelu_quick + +static void ggml_compute_forward_gelu_quick_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_quick_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu_quick( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_quick_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_silu + +static void ggml_compute_forward_silu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_silu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_silu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} +// ggml_compute_forward_leaky_relu + +static void ggml_compute_forward_leaky_relu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + float negative_slope; + 
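+    // note: the slope is packed as the first float of dst->op_params (an int32_t
+    // array), so it is read back with memcpy rather than a pointer cast; the same
+    // pattern is used for the eps parameter of the norm ops further below.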
memcpy(&negative_slope, dst->op_params, sizeof(float)); + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_leaky_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope); + } +} + +static void ggml_compute_forward_leaky_relu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_leaky_relu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_silu_back + +static void ggml_compute_forward_silu_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * grad, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, grad)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_backward_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1])), + (float *) ((char *) grad->data + i1*(grad->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_silu_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * grad, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_silu_back_f32(params, src0, grad, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_norm + +static void ggml_compute_forward_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)x[i00]; + } + + float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + ggml_float sum2 = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sum2 += (ggml_float)(v*v); + } + + float variance = sum2/ne00; + const float scale = 1.0f/sqrtf(variance + 
eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_norm_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_group_rms_norm + +static void ggml_compute_forward_rms_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)(x[i00] * x[i00]); + } + + const float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + // for (int i00 = 0; i00 < ne00; i00++) { + // y[i00] = x[i00]; + // } + + const float scale = 1.0f/sqrtf(mean + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_rms_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_rms_norm_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + // src1 is same shape as src0 => same indices + const int64_t i11 = i01; + const int64_t i12 = i02; + const int64_t i13 = i03; + + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + + ggml_float sum_xx = 0.0; + ggml_float sum_xdz = 0.0; + + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum_xx += (ggml_float)(x[i00] * x[i00]); + sum_xdz += (ggml_float)(x[i00] * dz[i00]); + } + + //const float mean = (float)(sum_xx)/ne00; + const float mean_eps = (float)(sum_xx)/ne00 + eps; + const float sum_eps = (float)(sum_xx) + eps*ne00; + //const float mean_xdz = (float)(sum_xdz)/ne00; + // we could cache rms from forward pass to improve performance. + // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms. 
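+                // summary of the long derivation below: with
+                //   rrms = 1/sqrt(sum(x*x)/N + eps)   and   sum_eps = sum(x*x) + N*eps,
+                // the backward pass reduces to
+                //   dx = (dz - x * sum(x*dz)/sum_eps) * rrms,
+                // which is exactly what the ggml_vec_cpy/scale/acc calls at the end of
+                // this loop body compute.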
+ //const float rms = sqrtf(mean_eps); + const float rrms = 1.0f / sqrtf(mean_eps); + //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) + + { + // z = rms_norm(x) + // + // rms_norm(src0) = + // scale( + // src0, + // div( + // 1, + // sqrt( + // add( + // scale( + // sum( + // sqr( + // src0)), + // (1.0/N)), + // eps)))); + + // postorder: + // ## op args grad + // 00 param src0 grad[#00] + // 01 const 1 + // 02 sqr (#00) grad[#02] + // 03 sum (#02) grad[#03] + // 04 const 1/N + // 05 scale (#03, #04) grad[#05] + // 06 const eps + // 07 add (#05, #06) grad[#07] + // 08 sqrt (#07) grad[#08] + // 09 div (#01,#08) grad[#09] + // 10 scale (#00,#09) grad[#10] + // + // backward pass, given grad[#10] + // #10: scale + // grad[#00] += scale(grad[#10],#09) + // grad[#09] += sum(mul(grad[#10],#00)) + // #09: div + // grad[#08] += neg(mul(grad[#09], div(#09,#08))) + // #08: sqrt + // grad[#07] += mul(grad[#08], div(0.5, #08)) + // #07: add + // grad[#05] += grad[#07] + // #05: scale + // grad[#03] += scale(grad[#05],#04) + // #03: sum + // grad[#02] += repeat(grad[#03], #02) + // #02: + // grad[#00] += scale(mul(#00, grad[#02]), 2.0) + // + // substitute and simplify: + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#02] = repeat(grad[#03], #02) + // grad[#02] = repeat(scale(grad[#05],#04), #02) + // grad[#02] = repeat(scale(grad[#07],#04), #02) + // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps)) + // a = b*c + d*e + // a = b*c*f/f + d*e*f/f + // a = (b*c*f + d*e*f)*(1/f) + // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c)) + // a = (b + d*e/c)*c + // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps) + // a = (dz + 
x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms + // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms + // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms + // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms + // a = (dz + x*div(-mean_xdz,mean_eps))*rrms + // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms) + // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + } + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // post-order: + // dx := x + // dx := scale(dx,-mean_xdz/mean_eps) + // dx := add(dx, dz) + // dx := scale(dx, rrms) + float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + ggml_vec_cpy_f32 (ne00, dx, x); + // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); + ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); + ggml_vec_acc_f32 (ne00, dx, dz); + ggml_vec_scale_f32(ne00, dx, rrms); + } + } + } +} + +static void ggml_compute_forward_rms_norm_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_back_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_group_norm + +static void ggml_compute_forward_group_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + const float eps = 1e-6f; // TODO: make this a parameter + + // TODO: optimize + + int n_channels = src0->ne[2]; + int n_groups = dst->op_params[0]; + int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; + for (int i = ith; i < n_groups; i+=nth) { + int start = i * n_channels_per_group; + int end = start + n_channels_per_group; + if (end > n_channels) { + end = n_channels; + } + int step = end - start; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + ggml_float sum = 0.0; + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)x[i00]; + } + } + } + float mean = sum / (ne00 * ne01 * step); + ggml_float sum2 = 0.0; + + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sum2 += (ggml_float)(v * v); + } + } + } + float variance = sum2 / (ne00 * ne01 * step); + const float scale = 1.0f / sqrtf(variance + eps); + + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + ggml_vec_scale_f32(ne00, y, scale); + } + } + } + } +} + +static void ggml_compute_forward_group_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + 
struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_group_norm_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_mul_mat + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) +// helper function to determine if it is better to use BLAS or not +// for large matrices, BLAS is faster +static bool ggml_compute_forward_mul_mat_use_blas( + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + //const int64_t ne00 = src0->ne[0]; + //const int64_t ne01 = src0->ne[1]; + + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float) + // all the experts for each batch element and the processing would become incredibly slow + // TODO: find the optimal values for these + if (dst->op != GGML_OP_MUL_MAT_ID && + ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + //src0->type == GGML_TYPE_F32 && + src1->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { + + /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ + return true; + } + + return false; +} +#endif + +static void ggml_compute_forward_mul_mat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + + const bool src1_cont = ggml_is_contiguous(src1); + + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + +#if defined(GGML_USE_CLBLAST) + if (ggml_cl_can_mul_mat(src0, src1, dst)) { + if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { + ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); + } + return; + } +#endif + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + if (params->ith != 0) { + return; + } + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + // broadcast src0 into src1 across 2nd,3rd dimension + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); + + if (type != GGML_TYPE_F32) { + float * const wdata = 
params->wdata; + ggml_to_float_t const to_float = type_traits[type].to_float; + + size_t id = 0; + for (int64_t i01 = 0; i01 < ne01; ++i01) { + to_float((const char *) x + i01*nb01, wdata + id, ne00); + id += ne00; + } + + assert(id*sizeof(float) <= params->wsize); + x = wdata; + } + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne1, ne01, ne10, + 1.0f, y, ne10, + x, ne00, + 0.0f, d, ne01); + } + } + + //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + if (src1->type != vec_dot_type) { + char * wdata = params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + assert(params->wsize >= ne11*ne12*ne13*row_size); + assert(src1->type == GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = ne1*ne12*ne13; // src1 rows + + //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); + + // distribute the thread work across the inner or outer loop based on which one is larger + + const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + const int64_t ith0 = ith % nth0; + const int64_t ith1 = ith / nth0; + + const int64_t dr0 = (nr0 + nth0 - 1)/nth0; + const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + + const int64_t ir010 = dr0*ith0; + const int64_t ir011 = MIN(ir010 + dr0, nr0); + + const int64_t ir110 = dr1*ith1; + const int64_t ir111 = MIN(ir110 + dr1, nr1); + + //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); + + // threads with no work simply yield (not sure if it helps) + if (ir010 >= ir011 || ir110 >= ir111) { + sched_yield(); + return; + } + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + // attempt to reduce false-sharing (does not seem to make a difference) + float tmp[16]; + + for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { + for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { + const int64_t i13 = (ir1/(ne12*ne1)); + const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1; + const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1); + + // broadcast src0 into src1 + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03); + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || 
src1->type != vec_dot_type + ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size + : (i11*nb11 + i12*nb12 + i13*nb13)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); + } + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } +} + +// ggml_compute_forward_mul_mat_id + +static void ggml_compute_forward_mul_mat_id( + const struct ggml_compute_params * params, + const struct ggml_tensor * ids, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + + const bool src1_cont = ggml_is_contiguous(src1); + + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + // row groups + const int id = ggml_get_op_params_i32(dst, 0); + const int n_as = ggml_get_op_params_i32(dst, 1); + + char * wdata_src1_end = (src1->type == vec_dot_type) ? 
+ (char *) params->wdata : + (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11] + + #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] + + if (params->type == GGML_TASK_INIT) { + char * wdata = params->wdata; + if (src1->type != vec_dot_type) { + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + assert(params->wsize >= ne11*ne12*ne13*row_size); + assert(src1->type == GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } + } + } + } + + // initialize matrix_row_counts + GGML_ASSERT(wdata == wdata_src1_end); + memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); + + // group rows by src0 matrix + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(row_id >= 0 && row_id < n_as); + MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01; + matrix_row_counts[row_id] += 1; + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // compute each matrix multiplication in sequence + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const struct ggml_tensor * src0_cur = dst->src[cur_a + 2]; + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1*ne12*ne13; // src1 rows + + //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); + + // distribute the thread work across the inner or outer loop based on which one is larger + + const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + const int64_t nth1 = nr0 > nr1 ? 
1 : nth; // parallelize by src1 rows + + const int64_t ith0 = ith % nth0; + const int64_t ith1 = ith / nth0; + + const int64_t dr0 = (nr0 + nth0 - 1)/nth0; + const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + + const int64_t ir010 = dr0*ith0; + const int64_t ir011 = MIN(ir010 + dr0, nr0); + + const int64_t ir110 = dr1*ith1; + const int64_t ir111 = MIN(ir110 + dr1, nr1); + + //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); + + // threads with no work simply yield (not sure if it helps) + if (ir010 >= ir011 || ir110 >= ir111) { + sched_yield(); + continue; + } + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + // attempt to reduce false-sharing (does not seem to make a difference) + float tmp[16]; + + for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { + for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { + const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix + const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1; + const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1); + const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11); + + // broadcast src0 into src1 + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03); + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? 
(i11 + i12*ne11 + i13*ne12*ne11)*row_size + : (i11*nb11 + i12*nb12 + i13*nb13)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); + } + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } + } + + #undef MMID_MATRIX_ROW +} + +// ggml_compute_forward_out_prod + +static void ggml_compute_forward_out_prod_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + // int64_t t0 = ggml_perf_time_us(); + // UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne3 == ne13); + GGML_ASSERT(ne03 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod + // TODO: #if defined(GGML_USE_CLBLAST) + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + bool use_blas = ggml_is_matrix(src0) && + ggml_is_matrix(src1) && + ggml_is_contiguous(src0) && + (ggml_is_contiguous(src1) || ggml_is_transposed(src1)); +#endif + + if (params->type == GGML_TASK_INIT) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst + if (use_blas) { + return; + } +#endif + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (use_blas) { + if (params->ith != 0) { // All threads other than the first do no work. + return; + } + // Arguments to ggml_compute_forward_out_prod (expressed as major,minor) + // src0: (k,n) + // src1: (k,m) + // dst: (m,n) + // + // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f) + // Also expressed as (major,minor) + // a: (m,k): so src1 transposed + // b: (k,n): so src0 + // c: (m,n) + // + // However, if ggml_is_transposed(src1) is true, then + // src1->data already contains a transposed version, so sgemm mustn't + // transpose it further. 
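+        // In other words, the call below computes (row-major)
+        //   C(m x n) = op(A)(m x k) * B(k x n)
+        // with A = src1->data (op(A) = A^T unless src1 is stored transposed),
+        //      B = src0->data (ldb = n), C = dst->data (ldc = n),
+        // where n = src0->ne[0], k = src0->ne[1], m = src1->ne[0], matching the
+        // asserts above (ne0 == ne00 and ne1 == ne10).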
+ + int n = src0->ne[0]; + int k = src0->ne[1]; + int m = src1->ne[0]; + + int transposeA, lda; + + if (!ggml_is_transposed(src1)) { + transposeA = CblasTrans; + lda = m; + } else { + transposeA = CblasNoTrans; + lda = k; + } + + float * a = (float *) ((char *) src1->data); + float * b = (float *) ((char *) src0->data); + float * c = (float *) ((char *) dst->data); + + cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n); + + return; + } +#endif + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // block-tiling attempt + const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); + const int64_t blck_1 = 16; + + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { + const int64_t bir1 = MIN(bir + blck_1, ir1); + for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { + const int64_t bne01 = MIN(bi01 + blck_0, ne01); + for (int64_t ir = bir; ir < bir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + +#if GGML_VEC_MAD_UNROLL > 2 + const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL); + for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); + } + for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#else + for (int64_t i01 = bi01; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#endif + } + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_out_prod_q_f32( + const 
struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + // int64_t t0 = ggml_perf_time_us(); + // UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 dim0 + GGML_ASSERT(nb00 == ggml_type_size(type)); + + // dst dim0 cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod + // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) + + if (params->type == GGML_TASK_INIT) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + + for (int64_t ir = ir0; ir < ir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + + for (int64_t i01 = 0; i01 < ne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + dequantize_row_q(s0, wdata, ne0); + ggml_vec_mad_f32(ne0, d, wdata, *s1); + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_out_prod( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + 
case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + { + ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(false); // todo + // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_out_prod_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_scale + +static void ggml_compute_forward_scale_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scale factor + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const size_t nb01 = src0->nb[1]; + + const size_t nb1 = dst->nb[1]; + + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); + } +} + +static void ggml_compute_forward_scale( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_scale_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_set + +static void ggml_compute_forward_set_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during set + // nb0 is implicitly element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace && (params->type == GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src1); + const int nc = src1->ne[0]; + + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + + // src0 and dst as viewed during set + const size_t nb0 = ggml_element_size(src0); + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 
0 : ne13-1); + + GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst)); + + GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + } +} + +static void ggml_compute_forward_set( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_set_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_cpy + +static void ggml_compute_forward_cpy( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, src0, dst); +} + +// ggml_compute_forward_cont + +static void ggml_compute_forward_cont( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, src0, dst); +} + +// ggml_compute_forward_reshape + +static void ggml_compute_forward_reshape( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(src0); + UNUSED(dst); +} + +// ggml_compute_forward_view + +static void ggml_compute_forward_view( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_permute + +static void ggml_compute_forward_permute( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_transpose + +static void ggml_compute_forward_transpose( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_get_rows + +static void ggml_compute_forward_get_rows_q( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr); + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == ggml_type_size(type)); + assert(ggml_nrows(dst) == nr); + + // TODO: multi-thread + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + 
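+                // src1 holds int32 row indices; each selected row of the quantized
+                // src0 is dequantized directly into the corresponding row of dst.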
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + dequantize_row_q( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } + } + } +} + +static void ggml_compute_forward_get_rows_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(ggml_fp16_t)); + assert(ggml_nrows(dst) == nr); + + // TODO: multi-thread + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + ggml_fp16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } + } + } +} + +static void ggml_compute_forward_get_rows_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(dst) == nr); + + // TODO: multi-thread + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + } + } + } +} + +static void ggml_compute_forward_get_rows( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + { + ggml_compute_forward_get_rows_q(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rows_f16(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rows_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// ggml_compute_forward_get_rows_back + +static void 
ggml_compute_forward_get_rows_back_f32_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_is_contiguous(dst)); + + // ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == GGML_TASK_INIT) { + memset(dst->data, 0, ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + GGML_ASSERT( dst->ne[0] == nc); + GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v); + } + } +} + +static void ggml_compute_forward_get_rows_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_is_contiguous(dst)); + + // ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == GGML_TASK_INIT) { + memset(dst->data, 0, ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + GGML_ASSERT( dst->ne[0] == nc); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) src0->data + i*src0->nb[1])); + } +} + +static void ggml_compute_forward_get_rows_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// ggml_compute_forward_diag + +static void ggml_compute_forward_diag_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne00 == ne1); + GGML_ASSERT(ne01 == 1); + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne3); + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = 0; i1 < ne1; i1++) { + 
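+                // src0 is a single row (ne01 == 1); row i1 of dst is zero everywhere
+                // except element i1, which takes s[i1], i.e. dst = diag(src0).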
float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02); + for (int i0 = 0; i0 < i1; i0++) { + d[i0] = 0; + } + d[i1] = s[i1]; + for (int i0 = i1+1; i0 < ne0; i0++) { + d[i0] = 0; + } + } + } + } +} + +static void ggml_compute_forward_diag( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_diag_mask_inf + +static void ggml_compute_forward_diag_mask_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst, + const float value) { + + const int ith = params->ith; + const int nth = params->nth; + + const int n_past = ((int32_t *) dst->op_params)[0]; + const bool inplace = src0->data == dst->data; + + GGML_ASSERT(n_past >= 0); + + if (!inplace && (params->type == GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + const int nr = src0->ne[1]; + const int nz = n/nr; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int k = 0; k < nz; k++) { + for (int j = ith; j < nr; j += nth) { + for (int i = n_past; i < nc; i++) { + if (i > n_past + j) { + *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value; + } + } + } + } +} + +static void ggml_compute_forward_diag_mask_inf( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_mask_f32(params, src0, dst, -INFINITY); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_diag_mask_zero( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_mask_f32(params, src0, dst, 0); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_soft_max + +static void ggml_compute_forward_soft_max_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(ggml_is_contiguous(dst)); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t ne11 = src1 ? 
src1->ne[1] : 1; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + + for (int i1 = ir0; i1 < ir1; i1++) { + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); + float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); + + // broadcast the mask across rows + float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + + ggml_vec_cpy_f32 (nc, wp, sp); + ggml_vec_scale_f32(nc, wp, scale); + if (mp) { + ggml_vec_acc_f32(nc, wp, mp); + } + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); + } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, wp); + + ggml_float sum = 0.0; + + uint16_t scvt; + for (int i = 0; i < nc; i++) { + if (wp[i] == -INFINITY) { + dp[i] = 0.0f; + } else { + // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]); + sum += (ggml_float)val; + dp[i] = val; + } + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(nc, dp, sum); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_soft_max_back + +static void ggml_compute_forward_soft_max_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src1, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *dy = (float *)((char *) src0->data + i1*src0->nb[1]); + float *y = (float *)((char *) src1->data + i1*src1->nb[1]); + float *dx = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(dy[i])); + assert(!isnan(y[i])); + } +#endif + // Jii = yi - yi*yi + // Jij = -yi*yj + // J = diag(y)-y.T*y + // dx = J * dy + // dxk = sum_i(Jki * dyi) + // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*dyk + // dxk = -yk * sum_i(yi * dyi) + yk*dyk + // dxk = -yk * dot(y, dy) + yk*dyk + // dxk = yk * (- dot(y, dy) + dyk) + // dxk = yk * 
(dyk - dot(y, dy)) + // + // post-order: + // dot_y_dy := dot(y, dy) + // dx := dy + // dx := dx - dot_y_dy + // dx := dx * y + + // linear runtime, no additional memory + float dot_y_dy = 0; + ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy); + ggml_vec_cpy_f32 (nc, dx, dy); + ggml_vec_acc1_f32(nc, dx, -dot_y_dy); + ggml_vec_mul_f32 (nc, dx, dx, y); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dx[i])); + assert(!isinf(dx[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_alibi + +static void ggml_compute_forward_alibi_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); + + const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int64_t ne1 = src0->ne[1]; // seq_len_without_past + const int64_t ne2 = src0->ne[2]; // n_head -> this is k + //const int64_t ne3 = src0->ne[3]; // 1 -> bsz + + const int64_t n = ggml_nrows(src0); + const int64_t ne2_ne3 = n/ne1; // ne2*ne3 + + const size_t nb0 = src0->nb[0]; + const size_t nb1 = src0->nb[1]; + const size_t nb2 = src0->nb[2]; + //const int nb3 = src0->nb[3]; + + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(n_head == ne2); + + // add alibi to src0 (KQ_scaled) + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + for (int64_t i = 0; i < ne0; i++) { + for (int64_t j = 0; j < ne1; j++) { + for (int64_t k = 0; k < ne2_ne3; k++) { + float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); + float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); + + // TODO: k*nb2 or k*nb3 + + float m_k; + + if (k < n_heads_log2_floor) { + m_k = powf(m0, k + 1); + } else { + m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + pdst[0] = i * m_k + src[0]; + } + } + } +} + +static void ggml_compute_forward_alibi_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); + + const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int ne1 = src0->ne[1]; // seq_len_without_past + const int ne2 = src0->ne[2]; // n_head -> this is k + //const int ne3 = src0->ne[3]; // 1 -> bsz + + const int n = ggml_nrows(src0); + const int ne2_ne3 = n/ne1; // ne2*ne3 + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + //const int nb3 = src0->nb[3]; + + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + 
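// per-head ALiBi slope, computed below: heads k < n_heads_log2_floor get m0^(k+1),
+    // the rest get odd powers of m1; e.g. n_head = 8 with max_bias = 8 gives slopes 1/2, 1/4, ..., 1/256,
+    // and the bias added to column i of head k is i * m_k +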
//GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; + GGML_ASSERT(n_head == ne2); + + // add alibi to src0 (KQ_scaled) + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + for (int i = 0; i < ne0; i++) { + for (int j = 0; j < ne1; j++) { + for (int k = 0; k < ne2_ne3; k++) { + ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); + float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); + + // TODO: k*nb2 or k*nb3 + + float m_k; + + if (k < n_heads_log2_floor) { + m_k = powf(m0, k + 1); + } else { + m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + // we return F32 + pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); + } + } + } +} + +static void ggml_compute_forward_alibi( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_alibi_f16(params, src0, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_alibi_f32(params, src0, dst); + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_K: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_clamp + +static void ggml_compute_forward_clamp_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + float min; + float max; + memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + + for (int i = 0; i < nc; i++) { + dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); + } + } +} + +static void ggml_compute_forward_clamp( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_clamp_f32(params, src0, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_K: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_rope + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + 
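// y measures where dimension pair i0/2 falls within the [low, high] correction range;
+    // the ramp is 1 below the range, 0 above it, and falls off linearly in between +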
return 1 - MIN(1, MAX(0, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) { + return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +} + +void ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base))); + dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base))); +} + +static void ggml_compute_forward_rope_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst, + const bool forward) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + + // these two only relevant for xPos RoPE: + float xpos_base; + bool xpos_down; + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float)); + memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool)); + + GGML_TENSOR_UNARY_OP_LOCALS + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + GGML_ASSERT(nb00 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(dst); + + GGML_ASSERT(n_dims <= ne0); + GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(freq_base, 
-2.0f/n_dims); + const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + // backward process uses inverse rotation by cos and sin. + // cos and sin build a rotation matrix, where the inverse is the transpose. + // this essentially just switches the sign of sin. + const float sin_sign = forward ? 1.0f : -1.0f; + + const int32_t * pos = (const int32_t *) src1->data; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta_base = (float)p; + + if (is_glm) { + theta_base = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base) * sin_sign; + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta) * sin_sign; + + theta_base *= theta_scale; + block_theta *= theta_scale; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + const float x2 = src[n_dims]; + const float x3 = src[n_dims/2*3]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; + dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; + } + } else if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta + ); + sin_theta *= sin_sign; + + // zeta scaling for xPos only: + float zeta = xpos_base != 0.0f ? 
powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; + if (xpos_down) zeta = 1.0f / zeta; + + theta_base *= theta_scale; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta; + dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; + } + } else { + // TODO: this might be wrong for ne0 != n_dims - need double check + // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + theta_base *= freq_scale; + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + // simplified from `(ib * n_dims + ic) * inv_ndims` + float cur_rot = inv_ndims * ic - ib; + + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + &cos_theta, &sin_theta + ); + sin_theta *= sin_sign; + + theta_base *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } + } + } + } + } +} + +static void ggml_compute_forward_rope_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst, + const bool forward) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(dst); + + GGML_ASSERT(n_dims <= ne0); + GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + // backward process uses inverse rotation by cos and 
sin. + // cos and sin build a rotation matrix, where the inverse is the transpose. + // this essentially just switches the sign of sin. + const float sin_sign = forward ? 1.0f : -1.0f; + + const int32_t * pos = (const int32_t *) src1->data; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta_base = (float)p; + + if (is_glm) { + theta_base = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base) * sin_sign; + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta) * sin_sign; + + theta_base *= theta_scale; + block_theta *= theta_scale; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + const float x2 = GGML_FP16_TO_FP32(src[n_dims]); + const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); + dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); + } + } else if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta + ); + sin_theta *= sin_sign; + + theta_base *= theta_scale; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[1]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } else { + // TODO: this might be wrong for ne0 != n_dims - need double check + // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + theta_base *= freq_scale; + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + // simplified from `(ib * n_dims + ic) * inv_ndims` + float cur_rot = inv_ndims * ic - ib; + + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + &cos_theta, &sin_theta + ); + sin_theta *= sin_sign; + + theta_base *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } + } + } + } + } +} + +static void ggml_compute_forward_rope( + const struct 
ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_rope_f16(params, src0, src1, dst, true); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_rope_f32(params, src0, src1, dst, true); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_rope_back + +static void ggml_compute_forward_rope_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_rope_f16(params, src0, src1, dst, false); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_rope_f32(params, src0, src1, dst, false); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_transpose_1d + +static void ggml_compute_forward_conv_transpose_1d_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + memset(params->wdata, 0, params->wsize); + + // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } + + // permute source data (src1) from (L x Cin) to (Cin x L) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + ggml_fp16_t * dst_data = wdata; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + // need to zero dst since we are accumulating into it + memset(dst->data, 0, ggml_nbytes(dst)); + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + + // total rows in dst + const int nr = ne1; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata_src = wdata + nk; + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f16(ne02, &v, + (ggml_fp16_t *) wdata_src + i1n, + (ggml_fp16_t *) wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; + } + } + } +} + +static void 
ggml_compute_forward_conv_transpose_1d_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02; + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) + { + float * const wdata = (float *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + nk; + float * dst_data = wdata; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = src[i10]; + } + } + } + + // need to zero dst since we are accumulating into it + memset(dst->data, 0, ggml_nbytes(dst)); + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + + // total rows in dst + const int nr = ne1; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * const wdata = (float *) params->wdata + 0; + float * const wdata_src = wdata + nk; + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + float * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f32(ne02, &v, + wdata_src + i1n, + wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; + } + } + } +} + +static void ggml_compute_forward_conv_transpose_1d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void ggml_compute_forward_im2col_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const 
int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne13 : ne12; + const int64_t IC = is_2D ? ne12 : ne11; + const int64_t IH = is_2D ? ne11 : 1; + const int64_t IW = ne10; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne2 : 1; + const int64_t OW = ne1; + + int ofs0 = is_2D ? nb13 : nb12; + int ofs1 = is_2D ? nb12 : nb11; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] + + for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); + } + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_im2col( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_im2col_f16(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(false); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_transpose_2d + +static void ggml_compute_forward_conv_transpose_2d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02*ne03; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + memset(params->wdata, 0, params->wsize); + + // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); + ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; + for (int64_t i01 = 0; i01 < ne01; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; + } + } + } + } + } + 
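// the permutations put Cin innermost so the dot products in the compute phase below
+        // (ggml_vec_dot_f16 over ne03) read both kernel and source data contiguously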
+ // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + for (int i12 = 0; i12 < ne12; i12++) { + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); + ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + } + + memset(dst->data, 0, ggml_nbytes(dst)); + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + const int32_t stride = ggml_get_op_params_i32(dst, 0); + + // total patches in dst + const int np = ne2; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata_src = wdata + nk; + + for (int i2 = ip0; i2 < ip1; i2++) { // Cout + float * dst_data = (float *)((char *) dst->data + i2*nb2); + ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; + for (int i11 = 0; i11 < ne11; i11++) { + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i11*ne10*ne12 + i10*ne12; + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f16(ne03, &v, + wdata_src + i1n, + wdata_kernel + i01*ne00*ne03 + i00*ne03); + dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; + } + } + } + } + } +} + +// ggml_compute_forward_pool_1d_sk_p0 + +static void ggml_compute_forward_pool_1d_sk_p0( + const struct ggml_compute_params * params, + const enum ggml_op_pool op, + const struct ggml_tensor * src, + const int k, + struct ggml_tensor * dst) { + assert(src->type == GGML_TYPE_F32); + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const char * cdata = (const char *)src->data; + const char * const data_end = cdata + ggml_nbytes(src); + float * drow = (float *)dst->data; + + const int64_t rs = dst->ne[0]; + + while (cdata < data_end) { + const float * const srow = (const float *)cdata; + + int j = 0; + + for (int64_t i = 0; i < rs; ++i) { + switch (op) { + case GGML_OP_POOL_AVG: drow[i] = 0; break; + case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + for (int ki = 0; ki < k; ++ki) { + switch (op) { + case GGML_OP_POOL_AVG: drow[i] += srow[j]; break; + case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + ++j; + } + switch (op) { + case GGML_OP_POOL_AVG: drow[i] /= k; break; + case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + } + + cdata += src->nb[1]; + drow += rs; + } +} + +// ggml_compute_forward_pool_1d + +static void ggml_compute_forward_pool_1d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int s0 = opts[2]; + const int p0 = opts[3]; + GGML_ASSERT(p0 == 0); // padding not supported + GGML_ASSERT(k0 == s0); // only s = k supported + + ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst); +} + +// ggml_compute_forward_pool_2d + +static void ggml_compute_forward_pool_2d( + const struct ggml_compute_params * params, + const struct 
ggml_tensor * src, + struct ggml_tensor * dst) { + assert(src->type == GGML_TYPE_F32); + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + const char * cdata = (const char*)src->data; + const char * const data_end = cdata + ggml_nbytes(src); + + const int64_t px = dst->ne[0]; + const int64_t py = dst->ne[1]; + const int64_t pa = px * py; + + float * dplane = (float *)dst->data; + + const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; + + while (cdata < data_end) { + for (int oy = 0; oy < py; ++oy) { + float * const drow = dplane + oy * px; + for (int ox = 0; ox < px; ++ox) { + float * const out = drow + ox; + switch (op) { + case GGML_OP_POOL_AVG: *out = 0; break; + case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; + const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= src->ne[0]) continue; + switch (op) { + case GGML_OP_POOL_AVG: *out += srow[j]; break; + case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + } + } + switch (op) { + case GGML_OP_POOL_AVG: *out /= ka; break; + case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + } + } + + cdata += src->nb[2]; + dplane += pa; + } +} + +// ggml_compute_forward_upscale + +static void ggml_compute_forward_upscale_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + const int scale_factor = dst->op_params[0]; + + // TODO: optimize + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const int64_t i01 = i1 / scale_factor; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const int64_t i00 = i0 / scale_factor; + + const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } + } +} + +static void ggml_compute_forward_upscale( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_upscale_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_pad + +static void ggml_compute_forward_pad_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + 
GGML_ASSERT( dst->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float * dst_ptr = (float *) dst->data; + + // TODO: optimize + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = ith; i1 < ne1; i1 += nth) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + for (int64_t i3 = 0; i3 < ne3; ++i3) { + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + + const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + dst_ptr[dst_idx] = *src_ptr; + } else { + dst_ptr[dst_idx] = 0; + } + } + } + } + } +} + +static void ggml_compute_forward_pad( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_pad_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_argsort + +static void ggml_compute_forward_argsort_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0); + + for (int64_t i = ith; i < nr; i += nth) { + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + + // C doesn't have a functional sort, so we do a bubble sort instead + for (int64_t j = 0; j < ne0; j++) { + for (int64_t k = j + 1; k < ne0; k++) { + if ((order == GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || + (order == GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { + int32_t tmp = dst_data[j]; + dst_data[j] = dst_data[k]; + dst_data[k] = tmp; + } + } + } + } +} + +static void ggml_compute_forward_argsort( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argsort_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_flash_attn + +static void ggml_compute_forward_flash_attn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == 
sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(masked_begin, S, scale); + + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; + } + + // softmax + // exclude known -INF S[..] 
values from max and loop + // dont forget to set their SW values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(masked_begin, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (i + j >= masked_begin) { + break; + } else if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { +#ifndef GGML_FLASH_ATTN_EXP_FP16 + const float val = expf(SS[j] - max); +#else + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); +#endif + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(masked_begin, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < masked_begin; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f32(masked_begin, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S); + } + } +} + +static void ggml_compute_forward_flash_attn_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale 
= %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2 % nek2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). 
+ if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } else { + for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_flash_ff + +static void ggml_compute_forward_flash_ff_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, // F16 + const struct ggml_tensor * b0, // F16 fc_w + const struct ggml_tensor * b1, // F32 fc_b + const struct ggml_tensor * c0, // F16 proj_w + const struct ggml_tensor * c1, // F32 proj_b + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, nea, a, ne) + GGML_TENSOR_LOCALS(size_t, nba, a, nb) + GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne) + GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb) + GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne) + GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb) + GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne) + GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb) + GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne) + GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = nea0; + //const int64_t N = nea1; + const int64_t M = neb01; + + GGML_ASSERT(ne0 == nea0); + GGML_ASSERT(ne1 == nea1); + GGML_ASSERT(ne2 == nea2); + + GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb10 == sizeof(float)); + GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbc10 == sizeof(float)); + + GGML_ASSERT(neb00 == D); + GGML_ASSERT(neb01 == M); + GGML_ASSERT(neb10 == M); + GGML_ASSERT(neb11 == 1); + + GGML_ASSERT(nec00 == M); + GGML_ASSERT(nec01 == D); + GGML_ASSERT(nec10 == D); + GGML_ASSERT(nec11 == 1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by a rows using ggml_vec_dot_f32 + + // total rows in a + const int nr = nea1*nea2*nea3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, 
nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // a indices + const int ia3 = ir/(nea2*nea1); + const int ia2 = (ir - ia3*nea2*nea1)/nea1; + const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); + + float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); + + for (int64_t ic = 0; ic < neb01; ++ic) { + // b0 indices + const int ib03 = ia3; + const int ib02 = ia2; + const int ib01 = ic; + + // S indices + const int i1 = ib01; + + ggml_vec_dot_f16(nea0, + S + i1, + (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), + (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); + } + + ggml_vec_add_f32(neb01, S, S, (float *) b1->data); + //ggml_vec_gelu_f32(neb01, S, S); + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + ggml_vec_gelu_f16(neb01, S16, S16); + + { + // dst indices + const int i1 = ia1; + const int i2 = ia2; + const int i3 = ia3; + + for (int64_t ic = 0; ic < nec01; ++ic) { + + ggml_vec_dot_f16(neb01, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), + S16); + } + + ggml_vec_add_f32(nec01, + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) c1->data); + } + } +} + +static void ggml_compute_forward_flash_ff( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b0, + const struct ggml_tensor * b1, + const struct ggml_tensor * c0, + const struct ggml_tensor * c1, + struct ggml_tensor * dst) { + switch (b0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(false); // TODO + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_flash_attn_back + +static void ggml_compute_forward_flash_attn_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * d, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ned, d, ne) + GGML_TENSOR_LOCALS(size_t, nbd, d, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + const int mxDM = MAX(D, Mup); + + // GGML_ASSERT(ne0 == D); + // GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + GGML_ASSERT(ned0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + GGML_ASSERT(ned1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + 
GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + if (ith == 0) { + memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); + } + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + const int64_t elem_q = ggml_nelements(q); + const int64_t elem_k = ggml_nelements(k); + + enum ggml_type result_type = dst->type; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + + void * grad_q = (char *) dst->data; + void * grad_k = (char *) dst->data + offs_k; + void * grad_v = (char *) dst->data + offs_v; + + const size_t nbgq1 = nb0*neq0; + const size_t nbgq2 = nb0*neq0*neq1; + const size_t nbgq3 = nb0*neq0*neq1*neq2; + + const size_t nbgk1 = nb0*nek0; + const size_t nbgk2 = nb0*nek0*nek1; + const size_t nbgk3 = nb0*nek0*nek1*neq2; + + const size_t nbgv1 = nb0*nev0; + const size_t nbgv2 = nb0*nev0*nev1; + const size_t nbgv3 = nb0*nev0*nev1*neq2; + + // parallelize by k rows using ggml_vec_dot_f32 + + // total rows in k + const int nr = nek2*nek3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + // how often k2 (and v2) is repeated in q2 + int nrep = neq2/nek2; + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int ik3 = ir/(nek2); + const int ik2 = ir - ik3*nek2; + + const int iq3 = ik3; + const int id3 = ik3; + const int iv3 = ik3; + const int iv2 = ik2; + + for (int irep = 0; irep < nrep; ++irep) { + const int iq2 = ik2 + irep*nek2; + const int id2 = iq2; + + // (ik2 + irep*nek2) % nek2 == ik2 + for (int iq1 = 0; iq1 < neq1; ++iq1) { + const int id1 = iq1; + + // not sure about CACHE_LINE_SIZE_F32.. + // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? + float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); + float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { + // k indices + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(masked_begin, S, scale); + + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; + } + + // softmax + // exclude known -INF S[..] 
values from max and loop + // dont forget to set their SM values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(masked_begin, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(SM, 1, &max, SM, 1, Mup); + vvexpf(SM, SM, &Mup); + ggml_vec_sum_f32(Mup, &sum, SM); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } + float * SR = S + i; + float * SW = SM + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (i + j >= masked_begin) { + break; + } else if (SR[j] == -INFINITY) { + SW[j] = 0.0f; + } else { +#ifndef GGML_FLASH_ATTN_EXP_FP16 + const float val = expf(SR[j] - max); +#else + ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); +#endif + sump[j] += (ggml_float)val; + SW[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(masked_begin, SM, sum); + + } + + // step-by-step explanation + { + // forward-process shape grads from backward process + // parallel_for ik2,ik3: + // for irep: + // iq2 = ik2 + irep*nek2 + // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur] + // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] + // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur] + // for iq1: + // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur + // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur + // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 + // S0 = -Inf [D,1,1,1] + // ~S1[i] = dot(kcur[:D,i], qcur) + // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale + // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) + // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur + // ~S5[i] = dot(vcur[:,i], S4) + // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3] + // ~dst[i,iq1,iq2,iq3] = S5[i] ^ + // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3] + // dst backward-/ grad[dst] = d + // + // output gradients with their dependencies: + // + // grad[kcur] = grad[S1].T @ qcur + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S4] = grad[S5] @ vcur + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[qcur] = grad[S1] @ kcur + // grad[vcur] = grad[S5].T @ S4 + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // in post-order: + // + // S1 = qcur @ kcur.T + // S2 = S1 * scale + // S3 = diag_mask_inf(S2, P) + // S4 = softmax(S3) + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[qcur] = grad[S1] @ kcur + // grad[kcur] = grad[S1].T @ qcur + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // using less variables (SM=S4): + // + // S = diag_mask_inf(qcur @ kcur.T * scale, P) + // SM = softmax(S) + // S = d[:D,iq1,iq2,iq3] @ vcur + // dot_SM_gradSM = dot(SM, S) + // S = SM * (S - dot(SM, S)) + // S = diag_mask_zero(S, P) * scale + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[k][:D,:M,ik2,ik3] += S.T @ qcur + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + } + + // S = gradSM 
= d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // for ic: + // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] + // exclude known future zero S[..] values from operation + ggml_vec_set_f32(masked_begin, S, 0); + for (int64_t ic = 0; ic < D; ++ic) { + ggml_vec_mad_f32(masked_begin, + S, + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } + + // S = SM * (S - dot(SM, S)) + float dot_SM_gradSM = 0; + ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S); + ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); + ggml_vec_mul_f32 (masked_begin, S, S, SM); + + // S = diag_mask_zero(S, P) * scale + // already done by above ggml_vec_set_f32 + + // exclude known zero S[..] values from operation + ggml_vec_scale_f32(masked_begin, S, scale); + + // S shape [M,1] + // SM shape [M,1] + // kcur shape [D,M] + // qcur shape [D,1] + // vcur shape [M,D] + + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] + // for ic: + // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + ggml_vec_mad_f32(D, + (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), + (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + S[ic]); + } + + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // for ic: + // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] + // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + ggml_vec_mad_f32(D, + (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), + S[ic]); + } + + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + // for ic: + // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] + // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] + // exclude known zero SM[..] 
values from mad + for (int64_t ic = 0; ic < D; ++ic) { + ggml_vec_mad_f32(masked_begin, + (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), + SM, + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } + } + } + } +} + +static void ggml_compute_forward_flash_attn_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * d, + const bool masked, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_win_part + +static void ggml_compute_forward_win_part_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + + const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t w = ((const int32_t *)(dst->op_params))[2]; + + assert(ne00 == ne0); + assert(ne3 == nep0*nep1); + + // TODO: optimize / multi-thread + for (int py = 0; py < nep1; ++py) { + for (int px = 0; px < nep0; ++px) { + const int64_t i3 = py*nep0 + px; + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i02 = py*w + i2; + const int64_t i01 = px*w + i1; + const int64_t i00 = i0; + + const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; + const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; + + if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { + ((float *) dst->data)[i] = 0.0f; + } else { + ((float *) dst->data)[i] = ((float *) src0->data)[j]; + } + } + } + } + } + } +} + +static void ggml_compute_forward_win_part( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_win_part_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_win_unpart + +static void ggml_compute_forward_win_unpart_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + + const int32_t w = ((const int32_t *)(dst->op_params))[0]; + + // padding + const int px = (w - ne1%w)%w; + //const int py = (w - ne2%w)%w; + + const int npx = (px + ne1)/w; + //const int npy = (py + ne2)/w; + + assert(ne0 == ne00); + + // TODO: optimize / multi-thread + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int ip2 = i2/w; + const int ip1 = i1/w; + + const int64_t i02 = i2%w; + const int64_t i01 = i1%w; + const int64_t i00 = i0; + + const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; + const int64_t j = i2*ne1*ne0 + i1*ne0 + i0; + + ((float *) dst->data)[j] = ((float *) src0->data)[i]; + } + } + } +} + +static void ggml_compute_forward_win_unpart( + const struct ggml_compute_params * params, + const 
struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_win_unpart_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +//gmml_compute_forward_unary + +static void ggml_compute_forward_unary( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + const enum ggml_unary_op op = ggml_get_unary_op(dst); + + switch (op) { + case GGML_UNARY_OP_ABS: + { + ggml_compute_forward_abs(params, src0, dst); + } break; + case GGML_UNARY_OP_SGN: + { + ggml_compute_forward_sgn(params, src0, dst); + } break; + case GGML_UNARY_OP_NEG: + { + ggml_compute_forward_neg(params, src0, dst); + } break; + case GGML_UNARY_OP_STEP: + { + ggml_compute_forward_step(params, src0, dst); + } break; + case GGML_UNARY_OP_TANH: + { + ggml_compute_forward_tanh(params, src0, dst); + } break; + case GGML_UNARY_OP_ELU: + { + ggml_compute_forward_elu(params, src0, dst); + } break; + case GGML_UNARY_OP_RELU: + { + ggml_compute_forward_relu(params, src0, dst); + } break; + case GGML_UNARY_OP_GELU: + { + ggml_compute_forward_gelu(params, src0, dst); + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + ggml_compute_forward_gelu_quick(params, src0, dst); + } break; + case GGML_UNARY_OP_SILU: + { + ggml_compute_forward_silu(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_get_rel_pos + +static void ggml_compute_forward_get_rel_pos_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 + + GGML_TENSOR_UNARY_OP_LOCALS + + const int64_t w = ne1; + + ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; + ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + const int64_t pos = (w - i1 - 1) + i2; + for (int64_t i0 = 0; i0 < ne0; ++i0) { + dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; + } + } + } +} + +static void ggml_compute_forward_get_rel_pos( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rel_pos_f16(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add_rel_pos + +static void ggml_compute_forward_add_rel_pos_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * src2, + struct ggml_tensor * dst) { + + const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; + if (!inplace && params->type == GGML_TASK_INIT) { + memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); + return; + } + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 + + float * src1_data = (float *) src1->data; + float * src2_data = (float *) src2->data; + float * dst_data = (float *) dst->data; + + const int64_t ne10 = src1->ne[0]; + const int64_t 
ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int ith = params->ith; + const int nth = params->nth; + + // total patches in dst + const int np = ne13; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + for (int64_t i13 = ip0; i13 < ip1; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10; + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t jp0 = jp1 + i10; + const float src1_e = src1_data[jp0]; + const float src2_e = src2_data[jp0]; + + const int64_t jdh = jp0 * ne10; + const int64_t jdw = jdh - (ne10 - 1) * i10; + + for (int64_t j = 0; j < ne10; ++j) { + dst_data[jdh + j ] += src2_e; + dst_data[jdw + j*ne10] += src1_e; + } + } + } + } + } +} + +static void ggml_compute_forward_add_rel_pos( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * src2, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_rel_pos_f32(params, src0, src1, src2, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_map_unary + +static void ggml_compute_forward_map_unary_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst, + const ggml_unary_op_f32_t fun) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_map_unary( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst, + const ggml_unary_op_f32_t fun) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_map_unary_f32(params, src0, dst, fun); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_map_binary + +static void ggml_compute_forward_map_binary_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst, + const ggml_binary_op_f32_t fun) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_map_binary( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst, + const ggml_binary_op_f32_t fun) { + switch (src0->type) { + case GGML_TYPE_F32: + { + 
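/*
 * [editor's note] ggml_compute_forward_map_binary_f32 above walks the rows of
 * two same-shaped F32 tensors and hands each row to a user-supplied callback
 * of type ggml_binary_op_f32_t: fun(n, dst_row, src0_row, src1_row). A
 * standalone sketch of that row-wise callback pattern (illustrative only, not
 * part of this patch; row_add and the flat arrays are made up, and plain
 * element strides stand in for the nb[1] byte strides used above):
 */
#include <stdio.h>

typedef void (*binary_op_f32_t)(int n, float * dst, const float * a, const float * b);

static void row_add(int n, float * dst, const float * a, const float * b) {
    for (int i = 0; i < n; ++i) {
        dst[i] = a[i] + b[i];
    }
}

int main(void) {
    enum { NC = 3, NR = 2 };   // row length and row count
    const float a[NR*NC] = { 1, 2, 3,   4, 5, 6 };
    const float b[NR*NC] = { 10, 10, 10,   20, 20, 20 };
    float dst[NR*NC];

    binary_op_f32_t fun = row_add;
    for (int i = 0; i < NR; ++i) {
        // one callback invocation per row, as in the loop above
        fun(NC, dst + i*NC, a + i*NC, b + i*NC);
    }

    for (int i = 0; i < NR*NC; ++i) {
        printf("%g ", dst[i]);
    }
    printf("\n");
    return 0;
}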
ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_map_custom1 + +static void ggml_compute_forward_map_custom1_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + struct ggml_tensor * dst, + const ggml_custom1_op_f32_t fun) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + fun(dst, a); +} + +// ggml_compute_forward_map_custom2 + +static void ggml_compute_forward_map_custom2_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b, + struct ggml_tensor * dst, + const ggml_custom2_op_f32_t fun) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + fun(dst, a, b); +} + +// ggml_compute_forward_map_custom3 + +static void ggml_compute_forward_map_custom3_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b, + const struct ggml_tensor * c, + struct ggml_tensor * dst, + const ggml_custom3_op_f32_t fun) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + fun(dst, a, b, c); +} + +// ggml_compute_forward_map_custom1 + +static void ggml_compute_forward_map_custom1( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) dst->op_params; + + p->fun(dst, a, params->ith, params->nth, p->userdata); +} + +// ggml_compute_forward_map_custom2 + +static void ggml_compute_forward_map_custom2( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) dst->op_params; + + p->fun(dst, a, b, params->ith, params->nth, p->userdata); +} + +// ggml_compute_forward_map_custom3 + +static void ggml_compute_forward_map_custom3( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b, + const struct ggml_tensor * c, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) dst->op_params; + + p->fun(dst, a, b, c, params->ith, params->nth, p->userdata); +} + +// ggml_compute_forward_cross_entropy_loss + +static void ggml_compute_forward_cross_entropy_loss_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_scalar(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + + const int ith = params->ith; + const int nth = params->nth; + + float * sums = (float *) params->wdata; + + // TODO: handle transposed/permuted matrices + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); + + if (params->type == 
GGML_TASK_INIT) { + if (ith == 0) { + memset(sums, 0, sizeof(float) * (nth + nth * nc)); + } + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (ith == 0) { + float * dp = (float *) dst->data; + ggml_vec_sum_f32(nth, dp, sums); + dp[0] *= -1.0f / (float) nr; + } + return; + } + + const double eps = 1e-9; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); + float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + float * st = ((float *) params->wdata) + nth + ith*nc; + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + // soft_max + ggml_float sum = 0.0; + { + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + + uint16_t scvt; UNUSED(scvt); + for (int i = 0; i < nc; i++) { + if (s0[i] == -INFINITY) { + st[i] = 0.0f; + } else { +#ifndef GGML_CROSS_ENTROPY_EXP_FP16 + const float s = s0[i] - max; + const float val = expf(s); +#else + ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]); +#endif + sum += (ggml_float)val; + st[i] = val; + } + } + + assert(sum > 0.0); + // sum = 1.0/sum; + } + // avoid log(0) by rescaling from [0..1] to [eps..1] + sum = (1.0 - eps) / sum; + ggml_vec_scale_f32(nc, st, sum); + ggml_vec_add1_f32(nc, st, st, eps); + ggml_vec_log_f32(nc, st, st); + ggml_vec_mul_f32(nc, st, st, s1); + + float st_sum = 0; + ggml_vec_sum_f32(nc, &st_sum, st); + sums[ith] += st_sum; + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(st[i])); + assert(!isinf(st[i])); + } +#endif + } + +} + +static void ggml_compute_forward_cross_entropy_loss( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_cross_entropy_loss_back + +static void ggml_compute_forward_cross_entropy_loss_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(opt0)); + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + const int64_t ith = params->ith; + const int64_t nth = params->nth; + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const double eps = 1e-9; + + // TODO: handle transposed/permuted matrices + const int64_t nc = src0->ne[0]; + const int64_t nr = ggml_nrows(src0); + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + float * d = (float *) opt0->data; + + for (int64_t i1 = ir0; i1 < ir1; i1++) { + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); + float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + 
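/*
 * [editor's note] The forward cross-entropy pass above computes, per row, the
 * numerically stable softmax of the logits s0 (max subtraction before expf),
 * rescales it from [0,1] to [eps,1] so log() never sees zero, and accumulates
 * sum_i s1[i]*log(p[i]); the FINALIZE step then multiplies the total by
 * -1/nr. A standalone single-row sketch of that computation (illustrative
 * only, not part of this patch; the logits and targets below are made up):
 */
#include <math.h>
#include <stdio.h>

int main(void) {
    enum { NC = 3 };
    const double eps = 1e-9;
    const float s0[NC] = { 2.0f, 1.0f, 0.1f };   // logits (hypothetical)
    const float s1[NC] = { 1.0f, 0.0f, 0.0f };   // target distribution (hypothetical)
    float st[NC];

    // stable softmax: subtract the row max before exponentiating
    float max = -INFINITY;
    for (int i = 0; i < NC; ++i) { if (s0[i] > max) max = s0[i]; }

    double sum = 0.0;
    for (int i = 0; i < NC; ++i) {
        st[i] = expf(s0[i] - max);
        sum += st[i];
    }

    // rescale from [0,1] to [eps,1] to avoid log(0), then accumulate t*log(p)
    const double scale = (1.0 - eps)/sum;
    double row = 0.0;
    for (int i = 0; i < NC; ++i) {
        st[i] = (float)(st[i]*scale + eps);
        row += s1[i]*log(st[i]);
    }

    const int nr = 1;   // a single row in this sketch
    printf("cross entropy loss = %f\n", -row/nr);
    return 0;
}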
+#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + + // soft_max + ggml_float sum = 0.0; + { + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + + uint16_t scvt; UNUSED(scvt); + for (int i = 0; i < nc; i++) { + if (s0[i] == -INFINITY) { + ds0[i] = 0.0f; + } else { +#ifndef GGML_CROSS_ENTROPY_EXP_FP16 + const float s = s0[i] - max; + const float val = expf(s); +#else + ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]); +#endif + sum += (ggml_float)val; + ds0[i] = val; + } + } + + assert(sum > 0.0); + sum = (1.0 - eps)/sum; + } + + // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr + ggml_vec_scale_f32(nc, ds0, sum); + ggml_vec_add1_f32(nc, ds0, ds0, eps); + ggml_vec_sub_f32(nc, ds0, ds0, s1); + ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(ds0[i])); + assert(!isinf(ds0[i])); + } +#endif + } +} + +static void ggml_compute_forward_cross_entropy_loss_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +///////////////////////////////// + +static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { + GGML_ASSERT(params); + + if (tensor->op == GGML_OP_NONE) { + return; + } + +#ifdef GGML_USE_CUBLAS + bool skip_cpu = ggml_cuda_compute_forward(params, tensor); + if (skip_cpu) { + return; + } + GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU); + GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU); +#endif // GGML_USE_CUBLAS + + switch (tensor->op) { + case GGML_OP_DUP: + { + ggml_compute_forward_dup(params, tensor->src[0], tensor); + } break; + case GGML_OP_ADD: + { + ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_ADD1: + { + ggml_compute_forward_add1(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_ACC: + { + ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SUB: + { + ggml_compute_forward_sub(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_MUL: + { + ggml_compute_forward_mul(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_DIV: + { + ggml_compute_forward_div(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SQR: + { + ggml_compute_forward_sqr(params, tensor->src[0], tensor); + } break; + case GGML_OP_SQRT: + { + ggml_compute_forward_sqrt(params, tensor->src[0], tensor); + } break; + case GGML_OP_LOG: + { + ggml_compute_forward_log(params, tensor->src[0], tensor); + } break; + case GGML_OP_SUM: + { + ggml_compute_forward_sum(params, tensor->src[0], tensor); + } break; + case GGML_OP_SUM_ROWS: + { + ggml_compute_forward_sum_rows(params, tensor->src[0], tensor); + } break; + case GGML_OP_MEAN: + { + ggml_compute_forward_mean(params, tensor->src[0], tensor); + } break; + case GGML_OP_ARGMAX: + { + ggml_compute_forward_argmax(params, tensor->src[0], tensor); + } 
break; + case GGML_OP_REPEAT: + { + ggml_compute_forward_repeat(params, tensor->src[0], tensor); + } break; + case GGML_OP_REPEAT_BACK: + { + ggml_compute_forward_repeat_back(params, tensor->src[0], tensor); + } break; + case GGML_OP_CONCAT: + { + ggml_compute_forward_concat(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SILU_BACK: + { + ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_NORM: + { + ggml_compute_forward_norm(params, tensor->src[0], tensor); + } break; + case GGML_OP_RMS_NORM: + { + ggml_compute_forward_rms_norm(params, tensor->src[0], tensor); + } break; + case GGML_OP_RMS_NORM_BACK: + { + ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_GROUP_NORM: + { + ggml_compute_forward_group_norm(params, tensor->src[0], tensor); + } break; + case GGML_OP_MUL_MAT: + { + ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_MUL_MAT_ID: + { + ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_OUT_PROD: + { + ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SCALE: + { + ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SET: + { + ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CPY: + { + ggml_compute_forward_cpy(params, tensor->src[0], tensor); + } break; + case GGML_OP_CONT: + { + ggml_compute_forward_cont(params, tensor->src[0], tensor); + } break; + case GGML_OP_RESHAPE: + { + ggml_compute_forward_reshape(params, tensor->src[0], tensor); + } break; + case GGML_OP_VIEW: + { + ggml_compute_forward_view(params, tensor->src[0]); + } break; + case GGML_OP_PERMUTE: + { + ggml_compute_forward_permute(params, tensor->src[0]); + } break; + case GGML_OP_TRANSPOSE: + { + ggml_compute_forward_transpose(params, tensor->src[0]); + } break; + case GGML_OP_GET_ROWS: + { + ggml_compute_forward_get_rows(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_GET_ROWS_BACK: + { + ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_DIAG: + { + ggml_compute_forward_diag(params, tensor->src[0], tensor); + } break; + case GGML_OP_DIAG_MASK_INF: + { + ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor); + } break; + case GGML_OP_DIAG_MASK_ZERO: + { + ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor); + } break; + case GGML_OP_SOFT_MAX: + { + ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SOFT_MAX_BACK: + { + ggml_compute_forward_soft_max_back(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_ROPE: + { + ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_ROPE_BACK: + { + ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_ALIBI: + { + ggml_compute_forward_alibi(params, tensor->src[0], tensor); + } break; + case GGML_OP_CLAMP: + { + ggml_compute_forward_clamp(params, tensor->src[0], tensor); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_IM2COL: + { + ggml_compute_forward_im2col(params, 
tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_POOL_1D: + { + ggml_compute_forward_pool_1d(params, tensor->src[0], tensor); + } break; + case GGML_OP_POOL_2D: + { + ggml_compute_forward_pool_2d(params, tensor->src[0], tensor); + } break; + case GGML_OP_UPSCALE: + { + ggml_compute_forward_upscale(params, tensor->src[0], tensor); + } break; + case GGML_OP_PAD: + { + ggml_compute_forward_pad(params, tensor->src[0], tensor); + } break; + case GGML_OP_ARGSORT: + { + ggml_compute_forward_argsort(params, tensor->src[0], tensor); + } break; + case GGML_OP_LEAKY_RELU: + { + ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor); + } break; + case GGML_OP_FLASH_ATTN: + { + const int32_t t = ggml_get_op_params_i32(tensor, 0); + GGML_ASSERT(t == 0 || t == 1); + const bool masked = t != 0; + ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); + } break; + case GGML_OP_FLASH_FF: + { + ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor); + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + int32_t t = ggml_get_op_params_i32(tensor, 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor); + } break; + case GGML_OP_WIN_PART: + { + ggml_compute_forward_win_part(params, tensor->src[0], tensor); + } break; + case GGML_OP_WIN_UNPART: + { + ggml_compute_forward_win_unpart(params, tensor->src[0], tensor); + } break; + case GGML_OP_UNARY: + { + ggml_compute_forward_unary(params, tensor->src[0], tensor); + } break; + case GGML_OP_GET_REL_POS: + { + ggml_compute_forward_get_rel_pos(params, tensor->src[0], tensor); + } break; + case GGML_OP_ADD_REL_POS: + { + ggml_compute_forward_add_rel_pos(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } break; + case GGML_OP_MAP_UNARY: + { + ggml_unary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun); + } + break; + case GGML_OP_MAP_BINARY: + { + ggml_binary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM1_F32: + { + ggml_custom1_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom1_f32(params, tensor->src[0], tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM2_F32: + { + ggml_custom2_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom2_f32(params, tensor->src[0], tensor->src[1], tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM3_F32: + { + ggml_custom3_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom3_f32(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM1: + { + ggml_compute_forward_map_custom1(params, tensor->src[0], tensor); + } + break; + case GGML_OP_MAP_CUSTOM2: + { + ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor); + } + break; + case GGML_OP_MAP_CUSTOM3: + { + ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } + break; + case 
GGML_OP_CROSS_ENTROPY_LOSS: + { + ggml_compute_forward_cross_entropy_loss(params, tensor->src[0], tensor->src[1], tensor); + } + break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + ggml_compute_forward_cross_entropy_loss_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } + break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +static size_t ggml_hash_size(size_t min_sz) { + // next primes after powers of two + static const size_t primes[] = { + 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031, + 2053, 4099, 8209, 16411, 32771, 65537, 131101, + 262147, 524309, 1048583, 2097169, 4194319, 8388617, + 16777259, 33554467, 67108879, 134217757, 268435459, + 536870923, 1073741827, 2147483659 + }; + static const size_t n_primes = sizeof(primes)/sizeof(primes[0]); + + // find the smallest prime that is larger or equal to min_sz + size_t l = 0; + size_t r = n_primes; + while (l < r) { + size_t m = (l + r)/2; + if (primes[m] < min_sz) { + l = m + 1; + } else { + r = m; + } + } + size_t sz = l < n_primes ? primes[l] : min_sz | 1; + return sz; +} + +static size_t ggml_hash(const void * p) { + return (size_t)p; +} + +size_t ggml_hash_find(const struct ggml_hash_set hash_set, struct ggml_tensor * key) { + size_t h = ggml_hash(key) % hash_set.size; + + // linear probing + size_t i = h; + while (hash_set.keys[i] != NULL && hash_set.keys[i] != key) { + i = (i + 1) % hash_set.size; + if (i == h) { + // visited all hash table entries -> not found + return GGML_HASHTABLE_FULL; + } + } + return i; +} + +bool ggml_hash_contains(struct ggml_hash_set hash_set, struct ggml_tensor * key) { + size_t i = ggml_hash_find(hash_set, key); + return i != GGML_HASHTABLE_FULL && hash_set.keys[i] == key; +} + +size_t ggml_hash_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) { + size_t i = ggml_hash_find(hash_set, key); + + GGML_ASSERT(i != GGML_HASHTABLE_FULL); + + if (hash_set.keys[i] == key) { + return GGML_HASHTABLE_ALREADY_EXISTS; + } + + // insert + GGML_ASSERT(hash_set.keys[i] == NULL); + hash_set.keys[i] = key; + return i; +} + +size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) { + size_t i = ggml_hash_find(hash_set, key); + + GGML_ASSERT(i != GGML_HASHTABLE_FULL); + + hash_set.keys[i] = key; + return i; +} + +static struct ggml_hash_set ggml_hash_set_new(size_t size) { + size = ggml_hash_size(size); + struct ggml_hash_set result; + result.size = size; + result.keys = malloc(sizeof(struct ggml_tensor *) * size); + memset(result.keys, 0, sizeof(struct ggml_tensor *) * size); + return result; +} + +static void ggml_hash_set_free(struct ggml_hash_set hash_set) { + free(hash_set.keys); +} + +struct hash_map { + struct ggml_hash_set set; + struct ggml_tensor ** vals; +}; + +static struct hash_map * ggml_new_hash_map(size_t size) { + struct hash_map * result = malloc(sizeof(struct hash_map)); + result->set = ggml_hash_set_new(size); + result->vals = malloc(sizeof(struct ggml_tensor *) * result->set.size); + memset(result->vals, 0, sizeof(struct ggml_tensor *) * result->set.size); + return result; +} + +static void ggml_hash_map_free(struct hash_map * map) { + ggml_hash_set_free(map->set); + free(map->vals); + free(map); +} + +// gradient checkpointing + +static struct ggml_tensor * ggml_recompute_graph_node( + struct ggml_context * ctx, + struct ggml_cgraph * graph, + struct hash_map * replacements, + 
struct ggml_tensor * node) { + + if (node == NULL) { + return NULL; + } + + if (node->is_param) { + return node; + } + + if (!ggml_hash_contains(graph->visited_hash_table, node)) { + return node; + } + + int count_children = 0; + for (int k = 0; k < GGML_MAX_SRC; ++k) { + if (node->src[k]) { + ++count_children; + } + } + + if (count_children == 0) { + return node; + } + + size_t i = ggml_hash_find(replacements->set, node); + GGML_ASSERT(i != GGML_HASHTABLE_FULL); // assert that not full + if (replacements->set.keys[i] == node) { + return replacements->vals[i]; + } + + struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, GGML_MAX_DIMS, node->ne); + + // insert clone into replacements + GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite + replacements->set.keys[i] = node; + replacements->vals[i] = clone; + + clone->op = node->op; + clone->grad = node->grad; + clone->is_param = node->is_param; + clone->extra = node->extra; + for (int k = 0; k < GGML_MAX_DIMS; ++k) { + clone->nb[k] = node->nb[k]; + } + for (int k = 0; k < GGML_MAX_SRC; ++k) { + clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); + } + if (node->view_src != NULL) { + clone->data = (node->view_src->data == NULL) + ? NULL // view_src not yet allocated + : (char *) node->view_src->data // view_src already allocated + + node->view_offs; + clone->view_src = node->view_src; + clone->view_offs = node->view_offs; + } + + GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t))); + GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME); + memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); + ggml_format_name(clone, "%s (clone)", ggml_get_name(node)); + + return clone; +} + +void ggml_build_backward_gradient_checkpointing( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + struct ggml_cgraph * gb_tmp, + struct ggml_tensor * * checkpoints, + int n_checkpoints) { + ggml_graph_cpy(gf, gb_tmp); + ggml_build_backward_expand(ctx, gf, gb_tmp, true); + + if (n_checkpoints <= 0) { + ggml_graph_cpy(gb_tmp, gb); + return; + } + + struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints); + + // insert checkpoints in replacements + for (int i = 0; i < n_checkpoints; ++i) { + size_t k = ggml_hash_find(replacements->set, checkpoints[i]); + GGML_ASSERT(k != GGML_HASHTABLE_FULL); // assert that not full + GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite + replacements->set.keys[k] = checkpoints[i]; + replacements->vals[k] = checkpoints[i]; + } + + ggml_graph_cpy(gf, gb); + // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], + // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), + // by recomputing them from checkpoints + for (int i = gf->n_nodes; in_nodes; ++i) { + struct ggml_tensor * node = gb_tmp->nodes[i]; + for (int k = 0; k < GGML_MAX_SRC; ++k) { + // insert new tensors recomputing src, reusing already made replacements, + // remember replacements: remember new tensors with mapping from corresponding gf nodes + // recurse for input tensors, + // unless (i.e. 
terminating when) input tensors are replacements (like checkpoints) + node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); + } + // insert rewritten backward node with replacements made into resulting backward graph gb + ggml_build_forward_expand(gb, node); + } + + ggml_hash_map_free(replacements); +} + +// functions to change gradients considering the case that input a might be initial gradient with zero value + +static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) { + if (ggml_hash_contains(zero_table, a)) { + return b; + } else { + return ggml_add_impl(ctx, a, b, false); + } +} + +static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) { + if (ggml_hash_contains(zero_table, a)) { + struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0)); + return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); + } else { + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); + } +} + +static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) { + if (ggml_hash_contains(zero_table, a)) { + return ggml_repeat(ctx, b, a); + } else { + return ggml_add1_impl(ctx, a, b, false); + } +} + +static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) { + if (ggml_hash_contains(zero_table, a)) { + return ggml_neg(ctx, b); + } else { + return ggml_sub_impl(ctx, a, b, false); + } +} + +static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) { + struct ggml_tensor * src0 = tensor->src[0]; + struct ggml_tensor * src1 = tensor->src[1]; + + switch (tensor->op) { + case GGML_OP_DUP: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + } + } break; + case GGML_OP_ADD: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + } + if (src1->grad) { + src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); + } + } break; + case GGML_OP_ADD1: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + } + if (src1->grad) { + src1->grad = ggml_add_or_set(ctx, + src1->grad, + ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean + zero_table); + } + } break; + case GGML_OP_ACC: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + } + if (src1->grad) { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; + + struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, + tensor->grad, + src1->grad->ne[0], + src1->grad->ne[1], + src1->grad->ne[2], + src1->grad->ne[3], + nb1, nb2, nb3, offset); + + src1->grad = + ggml_add_or_set(ctx, + src1->grad, + ggml_reshape(ctx, + ggml_cont(ctx, tensor_grad_view), + src1->grad), + zero_table); + } + } break; + case GGML_OP_SUB: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + } + if (src1->grad) { + 
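/*
 * [editor's note] The *_or_set helpers above exist because freshly created
 * gradients start out as zero tensors: while the current gradient is still in
 * zero_table, "accumulate g" degenerates to "set to g" (or "set to -g" for
 * subtraction), which skips a pointless add over the whole tensor. A
 * scalar-sized standalone sketch of that idea (an illustrative analogy only,
 * not part of this patch; the struct and helper names are made up):
 */
#include <stdio.h>
#include <stdbool.h>

struct grad_slot {
    float value;
    bool  is_zero;   // plays the role of membership in zero_table
};

static void add_or_set(struct grad_slot * g, float delta) {
    if (g->is_zero) {
        g->value   = delta;   // first contribution: just set it
        g->is_zero = false;
    } else {
        g->value  += delta;   // later contributions: accumulate
    }
}

int main(void) {
    struct grad_slot g = { 0.0f, true };
    add_or_set(&g, 0.5f);    // set
    add_or_set(&g, 0.25f);   // add
    printf("grad = %f\n", g.value);   // prints 0.750000
    return 0;
}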
src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table); + } + } break; + case GGML_OP_MUL: + { + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_mul(ctx, src1, tensor->grad), + zero_table); + } + if (src1->grad) { + src1->grad = + ggml_add_or_set(ctx, + src1->grad, + ggml_mul(ctx, src0, tensor->grad), + zero_table); + } + } break; + case GGML_OP_DIV: + { + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_div(ctx, tensor->grad, src1), + zero_table); + } + if (src1->grad) { + src1->grad = + ggml_sub_or_set(ctx, + src1->grad, + ggml_mul(ctx, + tensor->grad, + ggml_div(ctx, tensor, src1)), + zero_table); + } + } break; + case GGML_OP_SQR: + { + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_scale(ctx, + ggml_mul(ctx, src0, tensor->grad), + ggml_new_f32(ctx, 2.0f)), + zero_table); + } + } break; + case GGML_OP_SQRT: + { + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_scale(ctx, + ggml_div(ctx, + tensor->grad, + tensor), + ggml_new_f32(ctx, 0.5f)), + zero_table); + } + } break; + case GGML_OP_LOG: + { + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_div(ctx, + tensor->grad, + src0), + zero_table); + } + } break; + case GGML_OP_SUM: + { + if (src0->grad) { + src0->grad = + ggml_add1_or_set(ctx, + src0->grad, + tensor->grad, + zero_table); + } + } break; + case GGML_OP_SUM_ROWS: + { + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_repeat(ctx, + tensor->grad, + src0->grad), + zero_table); + } + } break; + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + { + GGML_ASSERT(false); // TODO: implement + } break; + case GGML_OP_REPEAT: + { + // necessary for llama + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_repeat_back(ctx, tensor->grad, src0->grad), + zero_table); + } + } break; + case GGML_OP_REPEAT_BACK: + { + if (src0->grad) { + // TODO: test this + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_repeat(ctx, tensor->grad, src0->grad), + zero_table); + } + } break; + case GGML_OP_CONCAT: + { + GGML_ASSERT(false); // TODO: implement + } break; + case GGML_OP_SILU_BACK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_NORM: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_RMS_NORM: + { + // necessary for llama + if (src0->grad) { + float eps; + memcpy(&eps, tensor->op_params, sizeof(float)); + + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_rms_norm_back(ctx, src0, tensor->grad, eps), + zero_table); + } + } break; + case GGML_OP_RMS_NORM_BACK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_MUL_MAT: + { + // https://cs231n.github.io/optimization-2/#staged + // # forward pass + // s0 = np.random.randn(5, 10) + // s1 = np.random.randn(10, 3) + // t = s0.dot(s1) + + // # now suppose we had the gradient on t from above in the circuit + // dt = np.random.randn(*t.shape) # same shape as t + // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix + // ds1 = t.T.dot(dt) + + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] + + // necessary for llama + if (src0->grad) { + struct ggml_tensor * s1_tg = + ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + tensor->grad); // [m,p,qq,rr] + const int64_t qq = s1_tg->ne[2]; + const int64_t rr = s1_tg->ne[3]; + const 
int64_t q1 = src0->ne[2]; + const int64_t r1 = src0->ne[3]; + const bool ne2_broadcasted = qq > q1; + const bool ne3_broadcasted = rr > r1; + if (ne2_broadcasted || ne3_broadcasted) { + // sum broadcast repetitions of s1_tg into shape of src0 + s1_tg = ggml_repeat_back(ctx, s1_tg, src0); + } + src0->grad = + ggml_add_or_set(ctx, + src0->grad, // [n,m,q1,r1] + s1_tg, // [n,m,q1,r1] + zero_table); + } + if (src1->grad) { + src1->grad = + ggml_add_or_set(ctx, + src1->grad, // [n,p,qq,rr] + // ggml_mul_mat(ctx, // [n,p,qq,rr] + // ggml_cont(ctx, // [m,n,q1,r1] + // ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // tensor->grad), // [m,p,qq,rr] + + // // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // // avoid transpose of src0, rather transpose smaller tensor->grad + // // and then use ggml_out_prod + ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + ggml_transpose(ctx, // [p,m,qq,rr] + tensor->grad)), // [m,p,qq,rr] + zero_table); + } + } break; + case GGML_OP_MUL_MAT_ID: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_OUT_PROD: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_SCALE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_scale_impl(ctx, tensor->grad, src1, false), + zero_table); + } + if (src1->grad) { + src1->grad = + ggml_add_or_set(ctx, + src1->grad, + ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)), + zero_table); + } + } break; + case GGML_OP_SET: + { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; + + struct ggml_tensor * tensor_grad_view = NULL; + + if (src0->grad || src1->grad) { + GGML_ASSERT(src0->type == tensor->type); + GGML_ASSERT(tensor->grad->type == tensor->type); + GGML_ASSERT(tensor->grad->type == src1->grad->type); + + tensor_grad_view = ggml_view_4d(ctx, + tensor->grad, + src1->grad->ne[0], + src1->grad->ne[1], + src1->grad->ne[2], + src1->grad->ne[3], + nb1, nb2, nb3, offset); + } + + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_acc_impl(ctx, + tensor->grad, + ggml_neg(ctx, tensor_grad_view), + nb1, nb2, nb3, offset, false), + zero_table); + } + + if (src1->grad) { + src1->grad = + ggml_add_or_set(ctx, + src1->grad, + ggml_reshape(ctx, + ggml_cont(ctx, tensor_grad_view), + src1->grad), + zero_table); + } + } break; + case GGML_OP_CPY: + { + // necessary for llama + // cpy overwrites value of src1 by src0 and returns view(src1) + // the overwriting is mathematically equivalent to: + // tensor = src0 * 1 + src1 * 0 + if (src0->grad) { + // dsrc0 = dtensor * 1 + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + } + if (src1->grad) { + // dsrc1 = dtensor * 0 -> noop + } + } break; + case GGML_OP_CONT: + { + // same as cpy + if (src0->grad) { + GGML_ASSERT(ggml_is_contiguous(src0->grad)); + GGML_ASSERT(ggml_is_contiguous(tensor->grad)); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + } + } break; + case GGML_OP_RESHAPE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, src0->grad, + ggml_reshape(ctx, + ggml_is_contiguous(tensor->grad) + ? 
tensor->grad + : ggml_cont(ctx, tensor->grad), + src0->grad), + zero_table); + } + } break; + case GGML_OP_VIEW: + { + // necessary for llama + if (src0->grad) { + size_t offset; + + memcpy(&offset, tensor->op_params, sizeof(offset)); + + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; + + if (src0->type != src0->grad->type) { + // gradient is typically F32, but src0 could be other type + size_t ng = ggml_element_size(src0->grad); + size_t n0 = ggml_element_size(src0); + GGML_ASSERT(offset % n0 == 0); + GGML_ASSERT(nb1 % n0 == 0); + GGML_ASSERT(nb2 % n0 == 0); + GGML_ASSERT(nb3 % n0 == 0); + offset = (offset / n0) * ng; + nb1 = (nb1 / n0) * ng; + nb2 = (nb2 / n0) * ng; + nb3 = (nb3 / n0) * ng; + } + + src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table); + } + } break; + case GGML_OP_PERMUTE: + { + // necessary for llama + if (src0->grad) { + int32_t * axes = (int32_t *) tensor->op_params; + int axis0 = axes[0] & 0x3; + int axis1 = axes[1] & 0x3; + int axis2 = axes[2] & 0x3; + int axis3 = axes[3] & 0x3; + int axes_backward[4] = {0,0,0,0}; + axes_backward[axis0] = 0; + axes_backward[axis1] = 1; + axes_backward[axis2] = 2; + axes_backward[axis3] = 3; + src0->grad = + ggml_add_or_set(ctx, src0->grad, + ggml_permute(ctx, + tensor->grad, + axes_backward[0], + axes_backward[1], + axes_backward[2], + axes_backward[3]), + zero_table); + } + } break; + case GGML_OP_TRANSPOSE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, src0->grad, + ggml_transpose(ctx, tensor->grad), + zero_table); + } + } break; + case GGML_OP_GET_ROWS: + { + // necessary for llama (only for tokenizer) + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, src0->grad, + // last ggml_get_rows_back argument src0->grad is only + // necessary to setup correct output shape + ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), + zero_table); + } + if (src1->grad) { + // noop + } + } break; + case GGML_OP_GET_ROWS_BACK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_DIAG: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_DIAG_MASK_INF: + { + // necessary for llama + if (src0->grad) { + const int n_past = ((int32_t *) tensor->op_params)[0]; + src0->grad = + ggml_add_or_set(ctx, src0->grad, + ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), + zero_table); + } + } break; + case GGML_OP_DIAG_MASK_ZERO: + { + // necessary for llama + if (src0->grad) { + const int n_past = ((int32_t *) tensor->op_params)[0]; + src0->grad = + ggml_add_or_set(ctx, src0->grad, + ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), + zero_table); + } + } break; + case GGML_OP_SOFT_MAX: + { + // necessary for llama + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, src0->grad, + ggml_soft_max_back(ctx, tensor->grad, tensor), + zero_table); + } + + } break; + case GGML_OP_SOFT_MAX_BACK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_ROPE: + { + // necessary for llama + if (src0->grad) { + //const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + + memcpy(&freq_base, (int32_t *) tensor->op_params + 5, 
sizeof(float)); + memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); + memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); + memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); + + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_rope_back(ctx, + tensor->grad, + src1, + n_dims, + mode, + n_ctx, + n_orig_ctx, + freq_base, + freq_scale, + ext_factor, + attn_factor, + beta_fast, + beta_slow, + xpos_base, + xpos_down), + zero_table); + } + } break; + case GGML_OP_ROPE_BACK: + { + if (src0->grad) { + //const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + + memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); + memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); + memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); + + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_rope_impl(ctx, + tensor->grad, + src1, + n_dims, + mode, + n_ctx, + n_orig_ctx, + freq_base, + freq_scale, + ext_factor, + attn_factor, + beta_fast, + beta_slow, + xpos_base, + xpos_down, + false), + zero_table); + } + } break; + case GGML_OP_ALIBI: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CLAMP: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_IM2COL: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_POOL_1D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_POOL_2D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_PAD: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_ARGSORT: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_LEAKY_RELU: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_FLASH_ATTN: + { + struct ggml_tensor * flash_grad = NULL; + if (src0->grad || src1->grad || tensor->src[2]->grad) { + int32_t t = ggml_get_op_params_i32(tensor, 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + flash_grad = + ggml_flash_attn_back(ctx, + src0, + src1, + tensor->src[2], + tensor->grad, + masked); + } + + struct ggml_tensor * src2 = tensor->src[2]; + const int64_t elem_q = ggml_nelements(src0); + const int64_t elem_k = 
ggml_nelements(src1); + const int64_t elem_v = ggml_nelements(src2); + + enum ggml_type result_type = flash_grad->type; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + + if (src0->grad) { + struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q); + struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0); + src0->grad = ggml_add_or_set(ctx, + src0->grad, + grad_q, + zero_table); + } + if (src1->grad) { + struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k); + struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1); + src1->grad = ggml_add_or_set(ctx, + src1->grad, + grad_k, + zero_table); + } + if (src2->grad) { + struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v); + struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2); + src2->grad = ggml_add_or_set(ctx, + src2->grad, + grad_v, + zero_table); + } + } break; + case GGML_OP_FLASH_FF: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_UNARY: + { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_ABS: + { + if (src0->grad) { + src0->grad = + ggml_add_or_set(ctx, + src0->grad, + ggml_mul(ctx, + ggml_sgn(ctx, src0), + tensor->grad), + zero_table); + } + } break; + case GGML_UNARY_OP_SGN: + { + if (src0->grad) { + // noop + } + } break; + case GGML_UNARY_OP_NEG: + { + if (src0->grad) { + src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table); + } + } break; + case GGML_UNARY_OP_STEP: + { + if (src0->grad) { + // noop + } + } break; + case GGML_UNARY_OP_TANH: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_ELU: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_RELU: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_mul(ctx, + ggml_step(ctx, src0), + tensor->grad), + zero_table); + } + } break; + case GGML_UNARY_OP_GELU: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_SILU: + { + // necessary for llama + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_silu_back(ctx, src0, tensor->grad), + zero_table); + } + } break; + default: + GGML_ASSERT(false); + } + } break; + case GGML_OP_GET_REL_POS: + case GGML_OP_ADD_REL_POS: + case GGML_OP_MAP_UNARY: + case GGML_OP_MAP_BINARY: + case GGML_OP_MAP_CUSTOM1_F32: + case GGML_OP_MAP_CUSTOM2_F32: + case GGML_OP_MAP_CUSTOM3_F32: + case GGML_OP_MAP_CUSTOM1: + case GGML_OP_MAP_CUSTOM2: + case GGML_OP_MAP_CUSTOM3: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + if (src0->grad) { + src0->grad = ggml_add_or_set(ctx, + src0->grad, + ggml_cross_entropy_loss_back(ctx, + src0, + src1, + tensor->grad), + zero_table); + } + } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } + + for (int i = 0; i < GGML_MAX_SRC; ++i) { + if (tensor->src[i] && tensor->src[i]->grad) { + 
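+ // NOTE: a gradient tensor is always expected to have the same shape as the tensor it
+ // belongs to; the assert below checks this invariant for every source after its
+ // backward rule has been applied.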
GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); + } + } +} + +static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { + if (node->grad == NULL) { + // this usually happens when we generate intermediate nodes from constants in the backward pass + // it can also happen during forward pass, if the user performs computations with constants + if (node->op != GGML_OP_NONE) { + //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); + } + } + + // check if already visited + if (ggml_hash_insert(cgraph->visited_hash_table, node) == GGML_HASHTABLE_ALREADY_EXISTS) { + return; + } + + for (int i = 0; i < GGML_MAX_SRC; ++i) { + const int k = + (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : + (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) : + /* unknown order, just fall back to using i*/ i; + if (node->src[k]) { + ggml_visit_parents(cgraph, node->src[k]); + } + } + + if (node->op == GGML_OP_NONE && node->grad == NULL) { + // reached a leaf node, not part of the gradient graph (e.g. a constant) + GGML_ASSERT(cgraph->n_leafs < cgraph->size); + + if (strlen(node->name) == 0) { + ggml_format_name(node, "leaf_%d", cgraph->n_leafs); + } + + cgraph->leafs[cgraph->n_leafs] = node; + cgraph->n_leafs++; + } else { + GGML_ASSERT(cgraph->n_nodes < cgraph->size); + + if (strlen(node->name) == 0) { + ggml_format_name(node, "node_%d", cgraph->n_nodes); + } + + cgraph->nodes[cgraph->n_nodes] = node; + if (cgraph->grads) { + cgraph->grads[cgraph->n_nodes] = node->grad; + } + cgraph->n_nodes++; + } +} + +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { + if (!expand) { + // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand + ggml_graph_clear(cgraph); + } + + const int n0 = cgraph->n_nodes; + UNUSED(n0); + + ggml_visit_parents(cgraph, tensor); + + const int n_new = cgraph->n_nodes - n0; + GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); + + if (n_new > 0) { + // the last added node should always be starting point + GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + } +} + +void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { + ggml_build_forward_impl(cgraph, tensor, true); +} + +void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) { + GGML_ASSERT(gf->n_nodes > 0); + + // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph + if (keep) { + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->grad) { + node->grad = ggml_dup_tensor(ctx, node); + gf->grads[i] = node->grad; + } + } + } + + // remember original gradients which start with zero values + struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size); + for (int i = 0; i < gf->n_nodes; i++) { + if (gf->grads[i]) { + ggml_hash_insert(zero_table, gf->grads[i]); + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = gf->nodes[i]; + + // inplace operations to add gradients are not created by ggml_compute_backward + // use allocator to automatically make inplace operations + if (node->grad) { + ggml_compute_backward(ctx, node, zero_table); + } + } + + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->is_param) { + 
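+ // NOTE: parameter tensors are the roots of the backward graph; expanding node->grad
+ // into gb pulls in every operation needed to evaluate the gradient of the graph
+ // output with respect to this parameter.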
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + ggml_build_forward_expand(gb, node->grad); + } + } + + ggml_hash_set_free(zero_table); +} + +static size_t ggml_graph_nbytes(size_t size, bool grads) { + size_t nbytes = sizeof(struct ggml_cgraph); + nbytes += size * sizeof(struct ggml_tensor *) * 2; // leafs + nodes + if (grads) { + nbytes += size * sizeof(struct ggml_tensor *); // grads + } + nbytes += ggml_hash_size(size * 2) * sizeof(struct ggml_tensor *); // hash set + return nbytes; +} + +size_t ggml_graph_overhead_custom(size_t size, bool grads) { + return GGML_OBJECT_SIZE + GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN); +} + +size_t ggml_graph_overhead(void) { + return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false); +} + +struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) { + const size_t obj_size = ggml_graph_nbytes(size, grads); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size); + struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); + + struct ggml_tensor ** data_start = (struct ggml_tensor **) (cgraph + 1); + + size_t hash_size = ggml_hash_size(size * 2); + struct ggml_tensor ** nodes_ptr = data_start; + struct ggml_tensor ** leafs_ptr = nodes_ptr + size; + struct ggml_tensor ** hash_keys_ptr = leafs_ptr + size; + struct ggml_tensor ** grads_ptr = grads ? hash_keys_ptr + hash_size : NULL; + + // check that we allocated the correct amount of memory + assert(obj_size == (size_t) ( + (grads ? (char *)(grads_ptr + size) : (char *)(hash_keys_ptr + hash_size)) - (char *)cgraph)); + + memset(hash_keys_ptr, 0, hash_size * sizeof(struct ggml_tensor *)); + + *cgraph = (struct ggml_cgraph) { + /*.size =*/ size, + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.nodes =*/ nodes_ptr, + /*.grads =*/ grads_ptr, + /*.leafs =*/ leafs_ptr, + /*.hash_table =*/ { hash_size, hash_keys_ptr }, + /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + + return cgraph; +} + +struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { + return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false); +} + +struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) { + struct ggml_cgraph cgraph = { + /*.size =*/ 0, + /*.n_nodes =*/ i1 - i0, + /*.n_leafs =*/ 0, + /*.nodes =*/ cgraph0->nodes + i0, + /*.grads =*/ cgraph0->grads ? 
cgraph0->grads + i0 : NULL, + /*.leafs =*/ NULL, + /*.hash_table =*/ { 0, NULL }, + /*.order =*/ cgraph0->order, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + + return cgraph; +} + +void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { + GGML_ASSERT(dst->size >= src->n_leafs); + GGML_ASSERT(dst->size >= src->n_nodes); + GGML_ASSERT(dst->visited_hash_table.size >= src->visited_hash_table.size); + + dst->n_leafs = src->n_leafs; + dst->n_nodes = src->n_nodes; + dst->order = src->order; + + for (int i = 0; i < src->n_leafs; ++i) { + dst->leafs[i] = src->leafs[i]; + } + + for (int i = 0; i < src->n_nodes; ++i) { + dst->nodes[i] = src->nodes[i]; + } + + if (src->grads) { + GGML_ASSERT(dst->grads != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + dst->grads[i] = src->grads[i]; + } + } + + for (size_t i = 0; i < src->visited_hash_table.size; ++i) { + if (src->visited_hash_table.keys[i]) { + ggml_hash_insert(dst->visited_hash_table, src->visited_hash_table.keys[i]); + } + } +} + +struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { + struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL); + ggml_graph_cpy(cgraph, result); + return result; +} + +void ggml_graph_reset(struct ggml_cgraph * cgraph) { + GGML_ASSERT(cgraph->grads != NULL); + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * grad = cgraph->grads[i]; + + if (grad) { + ggml_set_zero(grad); + } + } +} + +void ggml_graph_clear(struct ggml_cgraph * cgraph) { + cgraph->n_leafs = 0; + cgraph->n_nodes = 0; + memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *)); +} + +// +// thread data +// +// synchronization is done via busy loops +// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops +// + +#ifdef __APPLE__ + +//#include +// +//typedef os_unfair_lock ggml_lock_t; +// +//#define ggml_lock_init(x) UNUSED(x) +//#define ggml_lock_destroy(x) UNUSED(x) +//#define ggml_lock_lock os_unfair_lock_lock +//#define ggml_lock_unlock os_unfair_lock_unlock +// +//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#define ggml_lock_lock(x) UNUSED(x) +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#else + +//typedef pthread_spinlock_t ggml_lock_t; + +//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) +//#define ggml_lock_destroy pthread_spin_destroy +//#define ggml_lock_lock pthread_spin_lock +//#define ggml_lock_unlock pthread_spin_unlock + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) +#define ggml_lock_lock(x) _mm_pause() +#else +#define ggml_lock_lock(x) UNUSED(x) +#endif +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#endif + +// Android's libc implementation "bionic" does not support setting affinity +#if defined(__linux__) && !defined(__BIONIC__) +static void set_numa_thread_affinity(int thread_n, int n_threads) { + if (!ggml_is_numa()) { + 
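+ // NOTE: when NUMA support is not detected there is nothing to pin, so the OS
+ // scheduler is left in charge; otherwise each worker is bound below to the CPU set
+ // of NUMA node thread_n / (threads per node).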
return; + } + + // run thread on node_num thread_n / (threads per node) + const int node_num = thread_n / ((n_threads + g_state.numa.n_nodes - 1) / g_state.numa.n_nodes); + struct ggml_numa_node * node = &g_state.numa.nodes[node_num]; + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (size_t i = 0; i < node->n_cpus; ++i) { + CPU_SET_S(node->cpus[i], setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", + strerror(rv)); + } + + CPU_FREE(cpus); +} + +static void clear_numa_thread_affinity(void) { + if (!ggml_is_numa()) { + return; + } + + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) { + CPU_SET_S(i, setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", + strerror(rv)); + } + + CPU_FREE(cpus); +} +#else +// TODO: Windows etc. +// (the linux implementation may also work on BSD, someone should test) +static void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads); } +static void clear_numa_thread_affinity(void) {} +#endif + +struct ggml_compute_state_shared { + const struct ggml_cgraph * cgraph; + const struct ggml_cplan * cplan; + + int64_t perf_node_start_cycles; + int64_t perf_node_start_time_us; + + const int n_threads; + + // synchronization primitives + atomic_int n_active; // num active threads + atomic_int node_n; // active graph node + + bool (*abort_callback)(void * data); // abort ggml_graph_compute when true + void * abort_callback_data; +}; + +struct ggml_compute_state { + ggml_thread_t thrd; + int ith; + struct ggml_compute_state_shared * shared; +}; + +static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) { + int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles; + int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us; + + node->perf_runs++; + node->perf_cycles += cycles_cur; + node->perf_time_us += time_us_cur; +} + +static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { + int n_tasks = 0; + + switch (node->op) { + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_ADD: + case GGML_OP_ADD1: + case GGML_OP_ACC: + { + n_tasks = n_threads; + } break; + case GGML_OP_SUB: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_LEAKY_RELU: + { + n_tasks = 1; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_RELU: + { + n_tasks = 1; + } break; + + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + { + n_tasks = n_threads; + } break; + default: + GGML_ASSERT(false); + } + break; + case GGML_OP_SILU_BACK: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: + case 
GGML_OP_CONCAT: + { + n_tasks = n_threads; + } break; + case GGML_OP_MUL_MAT: + { + n_tasks = n_threads; + + // TODO: use different scheduling for different matrix sizes + //const int nr0 = ggml_nrows(node->src[0]); + //const int nr1 = ggml_nrows(node->src[1]); + + //n_tasks = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); + +#if defined(GGML_USE_CUBLAS) + if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#elif defined(GGML_USE_CLBLAST) + if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#endif +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#endif + } break; + case GGML_OP_MUL_MAT_ID: + { + n_tasks = n_threads; + } break; + case GGML_OP_OUT_PROD: + { + n_tasks = n_threads; + } break; + case GGML_OP_SCALE: + case GGML_OP_SET: + case GGML_OP_CONT: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: + case GGML_OP_GET_ROWS_BACK: + case GGML_OP_DIAG: + { + n_tasks = 1; + } break; + case GGML_OP_DIAG_MASK_ZERO: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_ADD_REL_POS: + { + n_tasks = n_threads; + } break; + case GGML_OP_ALIBI: + { + n_tasks = 1; //TODO + } break; + case GGML_OP_CLAMP: + { + n_tasks = 1; //TODO + } break; + case GGML_OP_SOFT_MAX: + { + n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0])); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + n_tasks = n_threads; + } break; + case GGML_OP_IM2COL: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + n_tasks = n_threads; + } break; + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + { + n_tasks = 1; + } break; + case GGML_OP_UPSCALE: + { + n_tasks = n_threads; + } break; + case GGML_OP_PAD: + { + n_tasks = n_threads; + } break; + case GGML_OP_ARGSORT: + { + n_tasks = n_threads; + } break; + case GGML_OP_FLASH_ATTN: + { + n_tasks = n_threads; + } break; + case GGML_OP_FLASH_FF: + { + n_tasks = n_threads; + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + n_tasks = n_threads; + } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + case GGML_OP_MAP_UNARY: + case GGML_OP_MAP_BINARY: + case GGML_OP_MAP_CUSTOM1_F32: + case GGML_OP_MAP_CUSTOM2_F32: + case GGML_OP_MAP_CUSTOM3_F32: + { + n_tasks = 1; + } break; + case GGML_OP_MAP_CUSTOM1: + { + struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params; + if (p->n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case GGML_OP_MAP_CUSTOM2: + { + struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params; + if (p->n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case GGML_OP_MAP_CUSTOM3: + { + struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params; + if (p->n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } 
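+ // NOTE: custom map ops carry a user-requested task count in their op params;
+ // GGML_N_TASKS_MAX means "use all available threads", anything else is clamped
+ // to n_threads.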
+ } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + n_tasks = n_threads; + } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + n_tasks = n_threads; + } break; + case GGML_OP_NONE: + { + n_tasks = 1; + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + default: + { + fprintf(stderr, "%s: op not implemented: ", __func__); + if (node->op < GGML_OP_COUNT) { + fprintf(stderr, "%s\n", ggml_op_name(node->op)); + } else { + fprintf(stderr, "%d\n", node->op); + } + GGML_ASSERT(false); + } break; + } + + assert(n_tasks > 0); + + return n_tasks; +} + +static thread_ret_t ggml_graph_compute_thread(void * data) { + struct ggml_compute_state * state = (struct ggml_compute_state *) data; + + const struct ggml_cgraph * cgraph = state->shared->cgraph; + const struct ggml_cplan * cplan = state->shared->cplan; + + const int n_threads = state->shared->n_threads; + + set_numa_thread_affinity(state->ith, n_threads); + + int node_n = -1; + + while (true) { + if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { + state->shared->node_n += 1; + return (thread_ret_t) GGML_EXIT_ABORTED; + } + if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { + // all other threads are finished and spinning + // do finalize and init here so we don't have synchronize again + struct ggml_compute_params params = { + /*.type =*/ GGML_TASK_FINALIZE, + /*.ith =*/ 0, + /*.nth =*/ 0, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + }; + + if (node_n != -1) { + /* FINALIZE */ + struct ggml_tensor * node = cgraph->nodes[node_n]; + if (GGML_OP_HAS_FINALIZE[node->op]) { + params.nth = ggml_get_n_tasks(node, n_threads); + ggml_compute_forward(¶ms, node); + } + ggml_graph_compute_perf_stats_node(node, state->shared); + } + + // distribute new work or execute it direct if 1T + while (++node_n < cgraph->n_nodes) { + GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); + + struct ggml_tensor * node = cgraph->nodes[node_n]; + const int n_tasks = ggml_get_n_tasks(node, n_threads); + + state->shared->perf_node_start_cycles = ggml_perf_cycles(); + state->shared->perf_node_start_time_us = ggml_perf_time_us(); + + params.nth = n_tasks; + + /* INIT */ + if (GGML_OP_HAS_INIT[node->op]) { + params.type = GGML_TASK_INIT; + ggml_compute_forward(¶ms, node); + } + + if (n_tasks == 1) { + // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, + // they do something more efficient than spinning (?) + params.type = GGML_TASK_COMPUTE; + ggml_compute_forward(¶ms, node); + + if (GGML_OP_HAS_FINALIZE[node->op]) { + params.type = GGML_TASK_FINALIZE; + ggml_compute_forward(¶ms, node); + } + + ggml_graph_compute_perf_stats_node(node, state->shared); + } else { + break; + } + + if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { + break; + } + } + + atomic_store(&state->shared->n_active, n_threads); + atomic_store(&state->shared->node_n, node_n); + } else { + // wait for other threads to finish + const int last = node_n; + while (true) { + // TODO: this sched_yield can have significant impact on the performance - either positive or negative + // depending on the workload and the operating system. 
+ // since it is not clear what is the best approach, it should potentially become user-configurable + // ref: https://github.com/ggerganov/ggml/issues/291 +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + sched_yield(); +#endif + + node_n = atomic_load(&state->shared->node_n); + if (node_n != last) break; + }; + } + + // check if we should stop + if (node_n >= cgraph->n_nodes) break; + + /* COMPUTE */ + struct ggml_tensor * node = cgraph->nodes[node_n]; + const int n_tasks = ggml_get_n_tasks(node, n_threads); + + struct ggml_compute_params params = { + /*.type =*/ GGML_TASK_COMPUTE, + /*.ith =*/ state->ith, + /*.nth =*/ n_tasks, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + }; + + if (state->ith < n_tasks) { + ggml_compute_forward(¶ms, node); + } + } + + return GGML_EXIT_SUCCESS; +} + +struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { + if (n_threads <= 0) { + n_threads = GGML_DEFAULT_N_THREADS; + } + + size_t work_size = 0; + + struct ggml_cplan cplan; + memset(&cplan, 0, sizeof(struct ggml_cplan)); + + // thread scheduling for the different operations + work buffer size estimation + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + const int n_tasks = ggml_get_n_tasks(node, n_threads); + + size_t cur = 0; + + switch (node->op) { + case GGML_OP_CPY: + case GGML_OP_DUP: + { + if (ggml_is_quantized(node->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; + } + } break; + case GGML_OP_ADD: + case GGML_OP_ADD1: + { + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + } break; + case GGML_OP_ACC: + { + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; + } + } break; + case GGML_OP_MUL_MAT: + { + const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; + +#if defined(GGML_USE_CLBLAST) + if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { + cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node); + } else +#endif +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { + if (node->src[0]->type != GGML_TYPE_F32) { + // here we need memory just for single 2D matrix from src0 + cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]); + } + } else +#endif + if (node->src[1]->type != vec_dot_type) { + cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); + } + } break; + case GGML_OP_MUL_MAT_ID: + { + const struct ggml_tensor * src0 = node->src[2]; + const struct ggml_tensor * src1 = node->src[1]; + const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; + if (src1->type != vec_dot_type) { + cur = ggml_row_size(vec_dot_type, ggml_nelements(src1)); + } + const int n_as = ggml_get_op_params_i32(node, 1); + cur = GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows + } break; + case GGML_OP_OUT_PROD: + { + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + } break; + case GGML_OP_SOFT_MAX: + { + cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(node->src[0]->ne[3] == 1); + GGML_ASSERT(node->src[1]->ne[2] == 1); 
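+ // NOTE: together these asserts restrict the op to a 3D kernel and a 2D input; the
+ // work size computed below stages the kernel (K x Cout x Cin) and the input (L x Cin)
+ // in the element type used for the actual computation.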
+ GGML_ASSERT(node->src[1]->ne[3] == 1); + + const int64_t ne00 = node->src[0]->ne[0]; // K + const int64_t ne01 = node->src[0]->ne[1]; // Cout + const int64_t ne02 = node->src[0]->ne[2]; // Cin + + const int64_t ne10 = node->src[1]->ne[0]; // L + const int64_t ne11 = node->src[1]->ne[1]; // Cin + + if (node->src[0]->type == GGML_TYPE_F16 && + node->src[1]->type == GGML_TYPE_F32) { + cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02; + cur += sizeof(ggml_fp16_t)*ne10*ne11; + } else if (node->src[0]->type == GGML_TYPE_F32 && + node->src[1]->type == GGML_TYPE_F32) { + cur += sizeof(float)*ne00*ne01*ne02; + cur += sizeof(float)*ne10*ne11; + } else { + GGML_ASSERT(false); + } + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + const int64_t ne00 = node->src[0]->ne[0]; // W + const int64_t ne01 = node->src[0]->ne[1]; // H + const int64_t ne02 = node->src[0]->ne[2]; // Channels Out + const int64_t ne03 = node->src[0]->ne[3]; // Channels In + + const int64_t ne10 = node->src[1]->ne[0]; // W + const int64_t ne11 = node->src[1]->ne[1]; // H + const int64_t ne12 = node->src[1]->ne[2]; // Channels In + + cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; + cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; + } break; + case GGML_OP_FLASH_ATTN: + { + const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); + + if (node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_F16) { + cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 + } + } break; + case GGML_OP_FLASH_FF: + { + if (node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_F16) { + cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 + } + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + const int64_t D = node->src[0]->ne[0]; + const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); + const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back + if (node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_F16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } + } break; + + case GGML_OP_CROSS_ENTROPY_LOSS: + { + cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + default: + break; + } + + work_size = MAX(work_size, cur); + } + + if (work_size > 0) { + work_size += CACHE_LINE_SIZE*(n_threads - 1); + } + + cplan.n_threads = n_threads; + cplan.work_size = work_size; + cplan.work_data = NULL; + + return cplan; +} + +int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { + { + GGML_ASSERT(cplan); + GGML_ASSERT(cplan->n_threads > 0); + + if (cplan->work_size > 0) { + GGML_ASSERT(cplan->work_data); + } + } + + const int n_threads = cplan->n_threads; + + 
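+ // NOTE: the calling thread acts as worker 0 and n_threads - 1 helper threads are
+ // spawned further down; workers coordinate through the atomics in state_shared
+ // (n_active, node_n) using the busy-loop scheme described earlier in this file,
+ // rather than through locks.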
struct ggml_compute_state_shared state_shared = { + /*.cgraph =*/ cgraph, + /*.cgraph_plan =*/ cplan, + /*.perf_node_start_cycles =*/ 0, + /*.perf_node_start_time_us =*/ 0, + /*.n_threads =*/ n_threads, + /*.n_active =*/ n_threads, + /*.node_n =*/ -1, + /*.abort_callback =*/ NULL, + /*.abort_callback_data =*/ NULL, + }; + struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); + + // create thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; ++j) { + workers[j] = (struct ggml_compute_state) { + .thrd = 0, + .ith = j, + .shared = &state_shared, + }; + + const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); + GGML_ASSERT(rc == 0); + UNUSED(rc); + } + } + + workers[0].ith = 0; + workers[0].shared = &state_shared; + + const int64_t perf_start_cycles = ggml_perf_cycles(); + const int64_t perf_start_time_us = ggml_perf_time_us(); + + // this is a work thread too + int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]); + + // don't leave affinity set on the main thread + clear_numa_thread_affinity(); + + // join or kill thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; j++) { + const int rc = ggml_thread_join(workers[j].thrd, NULL); + GGML_ASSERT(rc == 0); + } + } + + // performance stats (graph) + { + int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles; + int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us; + + cgraph->perf_runs++; + cgraph->perf_cycles += perf_cycles_cur; + cgraph->perf_time_us += perf_time_us_cur; + + GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", + __func__, cgraph->perf_runs, + (double) perf_cycles_cur / (double) ggml_cycles_per_ms(), + (double) cgraph->perf_cycles / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs, + (double) perf_time_us_cur / 1000.0, + (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); + } + + return compute_status; +} + +void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { + struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads); + + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + + cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; + + ggml_graph_compute(cgraph, &cplan); +} + +struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) { + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * leaf = cgraph->leafs[i]; + + if (strcmp(leaf->name, name) == 0) { + return leaf; + } + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + if (strcmp(node->name, name) == 0) { + return node; + } + } + + return NULL; +} + +static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fout) { + const int64_t * ne = tensor->ne; + const size_t * nb = tensor->nb; + + fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", + ggml_type_name(tensor->type), + ggml_op_name (tensor->op), + ggml_n_dims(tensor), + ne[0], ne[1], ne[2], ne[3], + nb[0], nb[1], nb[2], nb[3], + tensor->data, + tensor->name); +} + +static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char * arg, FILE * fout) { + const int64_t * ne = tensor->ne; + const size_t * nb = tensor->nb; + + fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p 
%32s\n", + arg, + ggml_type_name(tensor->type), + ggml_op_name (tensor->op), + ggml_n_dims(tensor), + ne[0], ne[1], ne[2], ne[3], + nb[0], nb[1], nb[2], nb[3], + tensor->data, + tensor->name); +} + +void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { + uint64_t size_eval = 0; + + // compute size of intermediate results + // TODO: does not take into account scratch buffers !!!! + for (int i = 0; i < cgraph->n_nodes; ++i) { + size_eval += ggml_nbytes_pad(cgraph->nodes[i]); + } + + // print + { + FILE * fout = stdout; + + fprintf(fout, "\n"); + fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC); + fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION); + fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs); + fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes); + fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval); + + // header + fprintf(fout, "\n"); + fprintf(fout, "%-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n", + "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME"); + + for (int i = 0; i < cgraph->n_leafs; ++i) { + ggml_graph_export_leaf(cgraph->leafs[i], fout); + + GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE); + GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL); + GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL); + } + + // header + fprintf(fout, "\n"); + fprintf(fout, "%-6s %-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %8s %16s %16s\n", + "ARG", "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "NTASKS", "DATA", "NAME"); + + for (int i = 0; i < cgraph->n_nodes; ++i) { + ggml_graph_export_node(cgraph->nodes[i], "DST", fout); + + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (cgraph->nodes[i]->src[j]) { + ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout); + } + } + + fprintf(fout, "\n"); + } + + fprintf(fout, "\n"); + } + + // write binary data + { + FILE * fout = fopen(fname, "wb"); + + if (!fout) { + fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + return; + } + + // header + { + const uint32_t magic = GGML_FILE_MAGIC; + const uint32_t version = GGML_FILE_VERSION; + const uint32_t n_leafs = cgraph->n_leafs; + const uint32_t n_nodes = cgraph->n_nodes; + + fwrite(&magic, sizeof(uint32_t), 1, fout); + fwrite(&version, sizeof(uint32_t), 1, fout); + fwrite(&n_leafs, sizeof(uint32_t), 1, fout); + fwrite(&n_nodes, sizeof(uint32_t), 1, fout); + fwrite(&size_eval, sizeof(uint64_t), 1, fout); + } + + // leafs + { + for (int i = 0; i < cgraph->n_leafs; ++i) { + const struct ggml_tensor * tensor = cgraph->leafs[i]; + + const uint32_t type = tensor->type; + const uint32_t op = tensor->op; + + fwrite(&type, sizeof(uint32_t), 1, fout); + fwrite(&op, sizeof(uint32_t), 1, fout); + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + const uint64_t ne = tensor->ne[j]; + const uint64_t nb = tensor->nb[j]; + + fwrite(&ne, sizeof(uint64_t), 1, fout); + fwrite(&nb, sizeof(uint64_t), 1, fout); + } + + fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); + + // dump the data + // TODO: pad this to 32 byte boundary + { + const size_t size = ggml_nbytes(tensor); + + fwrite(tensor->data, sizeof(char), size, fout); + } + } + } + + // nodes + { + for (int i = 0; i < cgraph->n_nodes; ++i) { + const struct ggml_tensor * tensor = cgraph->nodes[i]; + + const uint32_t type = tensor->type; + const uint32_t op = tensor->op; + + fwrite(&type, sizeof(uint32_t), 1, fout); + fwrite(&op, 
sizeof(uint32_t), 1, fout); + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + const uint64_t ne = tensor->ne[j]; + const uint64_t nb = tensor->nb[j]; + + fwrite(&ne, sizeof(uint64_t), 1, fout); + fwrite(&nb, sizeof(uint64_t), 1, fout); + } + + fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); + + // output the op arguments + { + struct ggml_tensor * args[GGML_MAX_SRC] = { NULL }; + + for (int j = 0; j < GGML_MAX_SRC; ++j) { + args[j] = tensor->src[j]; + } + + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (args[j]) { + int32_t idx = -1; + + // check if leaf + { + for (int k = 0; k < cgraph->n_leafs; ++k) { + if (args[j] == cgraph->leafs[k]) { + idx = k; + break; + } + } + } + + // check if node + if (idx == -1) { + for (int k = 0; k < cgraph->n_nodes; ++k) { + if (args[j] == cgraph->nodes[k]) { + idx = cgraph->n_leafs + k; + break; + } + } + } + + if (idx == -1) { + fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); + fclose(fout); + return; + } + + fwrite(&idx, sizeof(int32_t), 1, fout); + } else { + const int32_t nul = -1; + + fwrite(&nul, sizeof(int32_t), 1, fout); + } + } + } + } + } + + fclose(fout); + } +} + +struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) { + assert(*ctx_data == NULL); + assert(*ctx_eval == NULL); + + struct ggml_cgraph * result = NULL; + + struct ggml_tensor * data = NULL; + + // read file into data + { + FILE * fin = fopen(fname, "rb"); + if (!fin) { + fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + return result; + } + + size_t fsize = 0; + + fseek(fin, 0, SEEK_END); + fsize = ftell(fin); + fseek(fin, 0, SEEK_SET); + + // create the data context + { + const size_t overhead = 1*ggml_tensor_overhead(); + + struct ggml_init_params params = { + .mem_size = fsize + overhead, + .mem_buffer = NULL, + .no_alloc = false, + }; + + *ctx_data = ggml_init(params); + + if (!*ctx_data) { + fprintf(stderr, "%s: failed to create ggml context\n", __func__); + fclose(fin); + return result; + } + } + + data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize); + + { + const size_t ret = fread(data->data, sizeof(char), fsize, fin); + if (ret != fsize) { + fprintf(stderr, "%s: failed to read %s\n", __func__, fname); + fclose(fin); + return result; + } + } + + fclose(fin); + } + + // populate result + { + char * ptr = (char *) data->data; + + const uint32_t magic = *(const uint32_t *) ptr; ptr += sizeof(magic); + + if (magic != GGML_FILE_MAGIC) { + fprintf(stderr, "%s: invalid magic number, got %08x\n", __func__, magic); + return result; + } + + const uint32_t version = *(const uint32_t *) ptr; ptr += sizeof(version); + + if (version != GGML_FILE_VERSION) { + fprintf(stderr, "%s: invalid version number\n", __func__); + return result; + } + + const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs); + const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes); + const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval); + const int graph_size = MAX(n_leafs, n_nodes); + + // create the data context + { + const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph_size, false); + + struct ggml_init_params params = { + .mem_size = size_eval + overhead, + .mem_buffer = NULL, + .no_alloc = true, + }; + + *ctx_eval = ggml_init(params); + + if (!*ctx_eval) { + fprintf(stderr, "%s: failed to create ggml 
context\n", __func__); + return result; + } + } + + result = ggml_new_graph_custom(*ctx_eval, graph_size, false); + + result->n_leafs = n_leafs; + result->n_nodes = n_nodes; + + + // leafs + { + uint32_t type; + uint32_t op; + + for (uint32_t i = 0; i < n_leafs; ++i) { + type = *(const uint32_t *) ptr; ptr += sizeof(type); + op = *(const uint32_t *) ptr; ptr += sizeof(op); + + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + uint64_t ne_cur; + uint64_t nb_cur; + + ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); + nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); + + ne[j] = ne_cur; + nb[j] = nb_cur; + } + + struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, GGML_MAX_DIMS, ne); + + tensor->op = (enum ggml_op) op; + + memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; + memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS; + + tensor->data = (void *) ptr; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + tensor->nb[j] = nb[j]; + } + + result->leafs[i] = tensor; + + ptr += ggml_nbytes(tensor); + + fprintf(stderr, "%s: loaded leaf %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); + } + } + + ggml_set_no_alloc(*ctx_eval, false); + + // nodes + { + uint32_t type; + uint32_t op; + + for (uint32_t i = 0; i < n_nodes; ++i) { + type = *(const uint32_t *) ptr; ptr += sizeof(type); + op = *(const uint32_t *) ptr; ptr += sizeof(op); + + enum ggml_op eop = (enum ggml_op) op; + + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + uint64_t ne_cur; + uint64_t nb_cur; + + ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); + nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); + + ne[j] = ne_cur; + nb[j] = nb_cur; + } + + const char * ptr_name = ptr; ptr += GGML_MAX_NAME; + const char * ptr_op_params = ptr; ptr += GGML_MAX_OP_PARAMS; + + const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t); + + struct ggml_tensor * args[GGML_MAX_SRC] = { NULL }; + + // parse args + for (int j = 0; j < GGML_MAX_SRC; ++j) { + const int32_t arg_idx = ptr_arg_idx[j]; + + if (arg_idx == -1) { + continue; + } + + if (arg_idx < result->n_leafs) { + args[j] = result->leafs[arg_idx]; + } else { + args[j] = result->nodes[arg_idx - result->n_leafs]; + } + } + + // create the tensor + // "view" operations are handled differently + // TODO: handle inplace ops - currently a copy is always made + + struct ggml_tensor * tensor = NULL; + + switch (eop) { + // TODO: implement other view ops + case GGML_OP_RESHAPE: + { + tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]); + } break; + case GGML_OP_VIEW: + { + tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); + + size_t offs; + memcpy(&offs, ptr_op_params, sizeof(offs)); + + tensor->data = ((char *) tensor->data) + offs; + } break; + case GGML_OP_TRANSPOSE: + { + tensor = ggml_transpose(*ctx_eval, args[0]); + } break; + case GGML_OP_PERMUTE: + { + tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); + } break; + default: + { + tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, GGML_MAX_DIMS, ne); + + tensor->op = eop; + } break; + } + + memcpy(tensor->name, ptr_name, GGML_MAX_NAME); + memcpy(tensor->op_params, ptr_op_params, GGML_MAX_OP_PARAMS); + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + tensor->nb[j] = nb[j]; + } + + for (int j = 0; j < 
GGML_MAX_SRC; ++j) { + tensor->src[j] = args[j]; + } + + result->nodes[i] = tensor; + + fprintf(stderr, "%s: loaded node %d: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); + } + } + } + + return result; +} + +void ggml_graph_print(const struct ggml_cgraph * cgraph) { + int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0}; + + GGML_PRINT("=== GRAPH ===\n"); + + GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us); + + GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", + i, + node->ne[0], node->ne[1], node->ne[2], + ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs, + (double) node->perf_cycles / (double) ggml_cycles_per_ms(), + (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, + (double) node->perf_time_us / 1000.0, + (double) node->perf_time_us / 1000.0 / node->perf_runs); + } + + GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * node = cgraph->leafs[i]; + + GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", + i, + node->ne[0], node->ne[1], + ggml_op_name(node->op), + ggml_get_name(node)); + } + + for (int i = 0; i < GGML_OP_COUNT; i++) { + if (perf_total_per_op_us[i] == 0) { + continue; + } + + GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0); + } + + GGML_PRINT("========================================\n"); +} + +// check if node is part of the graph +static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + if (cgraph == NULL) { + return true; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return true; + } + } + + return false; +} + +static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * parent = cgraph->nodes[i]; + + if (parent->grad == node) { + return parent; + } + } + + return NULL; +} + +static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { + struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node); + struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent); + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", + gparent0 ? (void *) gparent0 : (void *) parent, + gparent0 ? "g" : "x", + gparent ? (void *) gparent : (void *) node, + gparent ? "g" : "x", + gparent ? "empty" : "vee", + gparent ? 
"dashed" : "solid", + label); +} + +static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n", + (void *) parent, "x", + (void *) node, "x", + label); +} + +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { + char color[16]; + + FILE * fp = fopen(filename, "w"); + GGML_ASSERT(fp); + + fprintf(fp, "digraph G {\n"); + fprintf(fp, " newrank = true;\n"); + fprintf(fp, " rankdir = LR;\n"); + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + if (ggml_graph_get_parent(gb, node) != NULL) { + continue; + } + + if (node->is_param) { + snprintf(color, sizeof(color), "yellow"); + } else if (node->grad) { + if (ggml_graph_find(gf, node)) { + snprintf(color, sizeof(color), "green"); + } else { + snprintf(color, sizeof(color), "lightblue"); + } + } else { + snprintf(color, sizeof(color), "white"); + } + + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); + + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", ggml_type_name(node->type)); + } + + if (ggml_is_matrix(node)) { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op)); + } else { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op)); + } + + if (node->grad) { + fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(node->grad->op)); + } else { + fprintf(fp, "\"; ]\n"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + snprintf(color, sizeof(color), "pink"); + + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); + + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", ggml_type_name(node->type)); + } + + fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); + if (ggml_nelements(node) < 5) { + fprintf(fp, " | ("); + for (int j = 0; j < ggml_nelements(node); j++) { + if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { + fprintf(fp, "%d", ggml_get_i32_1d(node, j)); + } + else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) { + fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j)); + } + else { + fprintf(fp, "#"); + } + if (j < ggml_nelements(node) - 1) { + fprintf(fp, ", "); + } + } + fprintf(fp, ")"); + } + fprintf(fp, "\"; ]\n"); + } + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j]) { + char label[16]; + snprintf(label, sizeof(label), "src %d", j); + ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label); + } + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j]) { + char label[16]; + snprintf(label, sizeof(label), "src %d", j); + ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label); + } + } + } + + fprintf(fp, "}\n"); + + fclose(fp); + + GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); +} + 
+//////////////////////////////////////////////////////////////////////////////// + +static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + // TODO: add function to set tensor from array + for (int64_t j = 0; j < ne; ++j) { + ggml_set_f32_1d(ps[p], j, x[i++]); + } + } +} + +static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + x[i++] = ggml_get_f32_1d(ps[p], j); + } + } +} + +static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { + int64_t i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + g[i++] = ggml_get_f32_1d(ps[p]->grad, j); + } + } +} + +static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) { + int64_t i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale; + } + } +} + +// +// ADAM +// +// ref: https://arxiv.org/pdf/1412.6980.pdf +// + +static enum ggml_opt_result ggml_opt_adam( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { + GGML_ASSERT(ggml_is_scalar(f)); + + // these will store the parameters we want to optimize + struct ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int64_t nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + GGML_ASSERT(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) { + int iter = opt->iter; + ggml_opt_init(opt->ctx, opt, params, nx); + opt->iter = iter; + } + + // constants + float sched = params.adam.sched; + const float alpha = params.adam.alpha; + const float decay = params.adam.decay * alpha; + const float beta1 = params.adam.beta1; + const float beta2 = params.adam.beta2; + const float eps = params.adam.eps; + const float gclip = params.adam.gclip; + const int decay_min_ndim = params.adam.decay_min_ndim; + const int n_accum = MAX(1, params.n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + + float * g = opt->adam.g->data; // gradients + float * m = opt->adam.m->data; // first moment + float * v = opt->adam.v->data; // second moment + + float * pf = params.past > 0 ? 
opt->adam.pf->data : NULL; // past function values + + struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; + + bool cancel = false; + + // compute the function value + float fx = 0; + ggml_set_zero(opt->adam.g); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + return GGML_OPT_CANCEL; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, &cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += ggml_get_f32_1d(f, 0); + } + fx *= accum_norm; + + opt->adam.fx_prev = fx; + opt->adam.fx_best = opt->adam.fx_prev; + if (pf) { + pf[opt->iter % params.past] = opt->adam.fx_prev; + } + + opt->loss_before = opt->adam.fx_prev; + opt->loss_after = opt->adam.fx_prev; + + // initialize + if (opt->just_initialized) { + opt->adam.n_no_improvement = 0; + opt->just_initialized = false; + } + + float * fx_best = &opt->adam.fx_best; + float * fx_prev = &opt->adam.fx_prev; + int * n_no_improvement = &opt->adam.n_no_improvement; + + int iter0 = opt->iter; + + // run the optimizer + for (int t = 0; t < params.adam.n_iter; ++t) { + opt->iter = iter0 + t + 1; + GGML_PRINT_DEBUG ("=== iter %d ===\n", t); + + GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0)); + GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0)); + + for (int i = 0; i < np; ++i) { + GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, + ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0)); + } + + const int64_t t_start_wall = ggml_time_us(); + const int64_t t_start_cpu = ggml_cycles(); + UNUSED(t_start_wall); + UNUSED(t_start_cpu); + + { + float gnorm = 1.0f; + if (gclip > 0.0f) { + // gradient clipping + ggml_float sum = 0.0; + for (int64_t i = 0; i < nx; ++i) { + sum += (ggml_float)(g[i]*g[i]); + } + ggml_float norm = sqrt(sum); + if (norm > (ggml_float) gclip) { + gnorm = (float) ((ggml_float) gclip / norm); + } + } + const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter)); + const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter)); + int64_t i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]); + const float p_decay = ((ggml_n_dims(ps[p]) >= decay_min_ndim) ? 
decay : 0.0f) * sched; + for (int64_t j = 0; j < ne; ++j) { + float x = ggml_get_f32_1d(ps[p], j); + float g_ = g[i]*gnorm; + m[i] = m[i]*beta1 + g_*(1.0f - beta1); + v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2); + float mh = m[i]*beta1h; + float vh = v[i]*beta2h; + vh = sqrtf(vh) + eps; + x = x*(1.0f - p_decay) - mh/vh; + ggml_set_f32_1d(ps[p], j, x); + ++i; + } + } + } + + fx = 0; + ggml_set_zero(opt->adam.g); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + return GGML_OPT_CANCEL;; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, &cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += ggml_get_f32_1d(f, 0); + } + fx *= accum_norm; + + opt->loss_after = fx; + + // check convergence + if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { + GGML_PRINT_DEBUG("converged\n"); + + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= iter0 + t) { + const float rate = (pf[(iter0 + t)%params.past] - fx)/fx; + + if (fabsf(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[(iter0 + t)%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx_best[0] > fx) { + fx_best[0] = fx; + n_no_improvement[0] = 0; + } else { + ++n_no_improvement[0]; + + if (n_no_improvement[0] >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + fx_prev[0] = fx; + + { + const int64_t t_end_cpu = ggml_cycles(); + GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); + UNUSED(t_end_cpu); + + const int64_t t_end_wall = ggml_time_us(); + GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); + UNUSED(t_end_wall); + } + } + + return GGML_OPT_DID_NOT_CONVERGE; +} + +// +// L-BFGS +// +// the L-BFGS implementation below is based on the following implementation: +// +// https://github.com/chokkan/liblbfgs +// + +struct ggml_lbfgs_iteration_data { + float alpha; + float ys; + float * s; + float * y; +}; + +static enum ggml_opt_result linesearch_backtracking( + const struct ggml_opt_params * params, + int nx, + float * x, + float * fx, + float * g, + float * d, + float * step, + const float * xp, + struct ggml_tensor * f, + struct ggml_cgraph * gb, + struct ggml_cplan * cplan, + const int np, + struct ggml_tensor * ps[], + bool * cancel, + ggml_opt_callback callback, + void * callback_data) { + int count = 0; + + float width = 0.0f; + float dg = 0.0f; + float finit = 0.0f; + float dginit = 0.0f; + float dgtest = 0.0f; + + const float dec = 0.5f; + const float inc = 2.1f; + + const int n_accum = MAX(1, params->n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + + if (*step <= 0.f) { + return GGML_LINESEARCH_INVALID_PARAMETERS; + } + + // compute the initial gradient in the search direction + ggml_vec_dot_f32(nx, &dginit, g, d); + + // make sure that d points to a descent direction + if (0 < dginit) { + return GGML_LINESEARCH_FAIL; + } + + // initialize local variables + finit = *fx; + dgtest = params->lbfgs.ftol*dginit; + + while (true) { + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_mad_f32(nx, x, d, *step); + + // evaluate the function and gradient values + { + ggml_opt_set_params(np, ps, x); + + *fx = 0; + memset(g, 0, sizeof(float)*nx); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if 
(callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, accum_step, &sched, cancel); + if (*cancel) { + return GGML_OPT_CANCEL; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + *fx += ggml_get_f32_1d(f, 0); + } + *fx *= accum_norm; + + } + + ++count; + + if (*fx > finit + (*step)*dgtest) { + width = dec; + } else { + // Armijo condition is satisfied + if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) { + return count; + } + + ggml_vec_dot_f32(nx, &dg, g, d); + + // check the Wolfe condition + if (dg < params->lbfgs.wolfe * dginit) { + width = inc; + } else { + if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) { + // regular Wolfe conditions + return count; + } + + if(dg > -params->lbfgs.wolfe*dginit) { + width = dec; + } else { + // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) + return count; + } + } + } + + if (*step < params->lbfgs.min_step) { + return GGML_LINESEARCH_MINIMUM_STEP; + } + if (*step > params->lbfgs.max_step) { + return GGML_LINESEARCH_MAXIMUM_STEP; + } + if (params->lbfgs.max_linesearch <= count) { + return GGML_LINESEARCH_MAXIMUM_ITERATIONS; + } + + (*step) *= width; + } + + GGML_UNREACHABLE(); +} + +static enum ggml_opt_result ggml_opt_lbfgs( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { + if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || + params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { + if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { + return GGML_OPT_INVALID_WOLFE; + } + } + + const int m = params.lbfgs.m; + + // these will store the parameters we want to optimize + struct ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + GGML_ASSERT(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) { + int iter = opt->iter; + ggml_opt_init(ctx, opt, params, nx); + opt->iter = iter; + } + + struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; + + float * x = opt->lbfgs.x->data; // current parameters + float * xp = opt->lbfgs.xp->data; // previous parameters + float * g = opt->lbfgs.g->data; // current gradient + float * gp = opt->lbfgs.gp->data; // previous gradient + float * d = opt->lbfgs.d->data; // search direction + + float * pf = params.past > 0 ? 
opt->lbfgs.pf->data : NULL; // past function values + + const int n_accum = MAX(1, params.n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + + float fx = 0.0f; // cost function value + float xnorm = 0.0f; // ||x|| + float gnorm = 0.0f; // ||g|| + + // initialize x from the graph nodes + ggml_opt_get_params(np, ps, x); + + // the L-BFGS memory + float * lm_alpha = opt->lbfgs.lmal->data; + float * lm_ys = opt->lbfgs.lmys->data; + float * lm_s = opt->lbfgs.lms->data; + float * lm_y = opt->lbfgs.lmy->data; + + bool cancel = false; + + // evaluate the function value and its gradient + { + ggml_opt_set_params(np, ps, x); + + fx = 0; + memset(g, 0, sizeof(float)*nx); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + return GGML_OPT_CANCEL; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, &cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += ggml_get_f32_1d(f, 0); + } + fx *= accum_norm; + + opt->loss_before = fx; + opt->loss_after = fx; + } + + // search direction = -gradient + ggml_vec_neg_f32(nx, d, g); + + // ||x||, ||g|| + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + if (xnorm < 1.0f) { + xnorm = 1.0f; + } + + // already optimized + if (gnorm/xnorm <= params.lbfgs.eps) { + return GGML_OPT_OK; + } + + if (opt->just_initialized) { + if (pf) { + pf[0] = fx; + } + opt->lbfgs.fx_best = fx; + + // initial step + ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d); + opt->lbfgs.j = 0; + opt->lbfgs.k = 1; + opt->lbfgs.end = 0; + opt->lbfgs.n_no_improvement = 0; + opt->just_initialized = false; + } + + float * fx_best = &opt->lbfgs.fx_best; + float * step = &opt->lbfgs.step; + int * j = &opt->lbfgs.j; + int * k = &opt->lbfgs.k; + int * end = &opt->lbfgs.end; + int * n_no_improvement = &opt->lbfgs.n_no_improvement; + + int ls = 0; + int bound = 0; + + float ys = 0.0f; + float yy = 0.0f; + float beta = 0.0f; + + int it = 0; + + while (true) { + // store the current position and gradient vectors + ggml_vec_cpy_f32(nx, xp, x); + ggml_vec_cpy_f32(nx, gp, g); + + // TODO: instead of passing &cancel here, use the return code of the linesearch + // to determine if the optimization should be cancelled + // this is a simple change, but not doing this atm, since I don't have a nice + // way to test and don't want to break something with so many changes lined up + ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data); + if (cancel) { + return GGML_OPT_CANCEL; + } + + if (ls < 0) { + // linesearch failed - go back to the previous point and return + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_cpy_f32(nx, g, gp); + + return ls; + } + + opt->loss_after = fx; + + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + + if (xnorm < 1.0f) { + xnorm = 1.0f; + } + if (gnorm/xnorm <= params.lbfgs.eps) { + // converged + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= k[0]) { + const float rate = (pf[k[0]%params.past] - fx)/fx; + + if (fabsf(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[k[0]%params.past] = fx; + } + + // check for improvement + if 
(params.max_no_improvement > 0) { + if (fx < fx_best[0]) { + fx_best[0] = fx; + n_no_improvement[0] = 0; + } else { + n_no_improvement[0]++; + + if (n_no_improvement[0] >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) { + // reached the maximum number of iterations + return GGML_OPT_DID_NOT_CONVERGE; + } + + // update vectors s and y: + // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. + // y_{k+1} = g_{k+1} - g_{k}. + // + ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp); + ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp); + + // compute scalars ys and yy: + // ys = y^t \cdot s -> 1 / \rho. + // yy = y^t \cdot y. + // + ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]); + ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]); + + lm_ys[end[0]] = ys; + + // find new search direction + // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS + + bound = (m <= k[0]) ? m : k[0]; + k[0]++; + it++; + end[0] = (end[0] + 1)%m; + + // initialize search direction with -g + ggml_vec_neg_f32(nx, d, g); + + j[0] = end[0]; + for (int i = 0; i < bound; ++i) { + j[0] = (j[0] + m - 1) % m; + // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} + ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d); + lm_alpha[j[0]] /= lm_ys[j[0]]; + // q_{i} = q_{i+1} - \alpha_{i} y_{i} + ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]); + } + + ggml_vec_scale_f32(nx, d, ys/yy); + + for (int i = 0; i < bound; ++i) { + // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} + ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d); + beta /= lm_ys[j[0]]; + // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} + ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta); + j[0] = (j[0] + 1)%m; + } + + step[0] = 1.0; + } + + GGML_UNREACHABLE(); +} + +struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { + struct ggml_opt_params result; + + switch (type) { + case GGML_OPT_ADAM: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_ADAM, + .graph_size = GGML_DEFAULT_GRAPH_SIZE, + .n_threads = 1, // FIXME: GGML_DEFAULT_N_THREADS ? 
+ .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 100, + + .print_forward_graph = true, + .print_backward_graph = true, + + .n_gradient_accumulation = 1, + + .adam = { + .n_iter = 10000, + .sched = 1.000f, + .decay = 0.0f, + .decay_min_ndim = 2, + .alpha = 0.001f, + .beta1 = 0.9f, + .beta2 = 0.999f, + .eps = 1e-8f, + .eps_f = 1e-5f, + .eps_g = 1e-3f, + .gclip = 0.0f, + }, + }; + } break; + case GGML_OPT_LBFGS: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_LBFGS, + .graph_size = GGML_DEFAULT_GRAPH_SIZE, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 0, + + .print_forward_graph = true, + .print_backward_graph = true, + + .n_gradient_accumulation = 1, + + .lbfgs = { + .m = 6, + .n_iter = 100, + .max_linesearch = 20, + + .eps = 1e-5f, + .ftol = 1e-4f, + .wolfe = 0.9f, + .min_step = 1e-20f, + .max_step = 1e+20f, + + .linesearch = GGML_LINESEARCH_DEFAULT, + }, + }; + } break; + } + + return result; +} + +GGML_API void ggml_opt_init( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + int64_t nx) { + opt->ctx = ctx; + opt->params = params; + opt->iter = 0; + opt->nx = nx; + opt->just_initialized = true; + if (opt->ctx == NULL) { + struct ggml_init_params ctx_opt_params; + if (opt->params.type == GGML_OPT_ADAM) { + ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3; + if (opt->params.past > 0) { + ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; + } + } else if (opt->params.type == GGML_OPT_LBFGS) { + ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2); + if (opt->params.past > 0) { + ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; + } + } + ctx_opt_params.mem_buffer = NULL; + ctx_opt_params.no_alloc = false; + + opt->ctx = ggml_init(ctx_opt_params); + } + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->adam.pf = params.past > 0 + ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) + : NULL; + ggml_set_zero(opt->adam.m); + ggml_set_zero(opt->adam.v); + if (opt->adam.pf) { + ggml_set_zero(opt->adam.pf); + } + } break; + case GGML_OPT_LBFGS: + { + opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.pf = params.past > 0 + ? 
ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) + : NULL; + opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + ggml_set_zero(opt->lbfgs.x); + ggml_set_zero(opt->lbfgs.xp); + ggml_set_zero(opt->lbfgs.g); + ggml_set_zero(opt->lbfgs.gp); + ggml_set_zero(opt->lbfgs.d); + if (opt->lbfgs.pf) { + ggml_set_zero(opt->lbfgs.pf); + } + ggml_set_zero(opt->lbfgs.lmal); + ggml_set_zero(opt->lbfgs.lmys); + ggml_set_zero(opt->lbfgs.lms); + ggml_set_zero(opt->lbfgs.lmy); + } break; + } +} + +enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f) { + bool free_ctx = false; + if (ctx == NULL) { + struct ggml_init_params params_ctx = { + .mem_size = 16*1024*1024, + .mem_buffer = NULL, + .no_alloc = false, + }; + + ctx = ggml_init(params_ctx); + if (ctx == NULL) { + return GGML_OPT_NO_CONTEXT; + } + + free_ctx = true; + } + + enum ggml_opt_result result = GGML_OPT_OK; + + struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context)); + + ggml_opt_init(ctx, opt, params, 0); + result = ggml_opt_resume(ctx, opt, f); + + if (free_ctx) { + ggml_free(ctx); + } + + return result; +} + +enum ggml_opt_result ggml_opt_resume( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f) { + + // build forward + backward compute graphs + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, opt->params.graph_size, true); + ggml_build_forward_expand(gf, f); + + struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf); + ggml_build_backward_expand(ctx, gf, gb, true); + + return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); +} + +enum ggml_opt_result ggml_opt_resume_g( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { + + // build forward + backward compute graphs + enum ggml_opt_result result = GGML_OPT_OK; + + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data); + } break; + case GGML_OPT_LBFGS: + { + result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data); + } break; + } + + if (opt->params.print_forward_graph) { + ggml_graph_print (gf); + ggml_graph_dump_dot(gf, NULL, "opt-forward.dot"); + } + + if (opt->params.print_backward_graph) { + ggml_graph_print (gb); + ggml_graph_dump_dot(gb, gf, "opt-backward.dot"); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK4_0 == 0); + const int nb = k / QK4_0; + + for (int b = 0; b < n; b += k) { + block_q4_0 * restrict y = (block_q4_0 *) dst + b/QK4_0; + + quantize_row_q4_0_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK4_0; j += 2) { + const uint8_t vi0 = y[i].qs[j/2] & 0x0F; + const uint8_t vi1 = y[i].qs[j/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK4_0*sizeof(block_q4_0)); +} + +size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK4_1 == 0); 
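+    // same pattern as ggml_quantize_q4_0 above: quantize each k-element chunk with the reference Q4_1
+    // row quantizer, then count the resulting 4-bit values into hist so callers can inspect their distribution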
+ const int nb = k / QK4_1; + + for (int b = 0; b < n; b += k) { + block_q4_1 * restrict y = (block_q4_1 *) dst + b/QK4_1; + + quantize_row_q4_1_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK4_1; j += 2) { + const uint8_t vi0 = y[i].qs[j/2] & 0x0F; + const uint8_t vi1 = y[i].qs[j/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK4_1*sizeof(block_q4_1)); +} + +size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int b = 0; b < n; b += k) { + block_q5_0 * restrict y = (block_q5_0 *)dst + b/QK5_0; + + quantize_row_q5_0_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int j = 0; j < QK5_0; j += 2) { + const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4; + const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12)); + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_0*sizeof(block_q5_0)); +} + +size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; + + for (int b = 0; b < n; b += k) { + block_q5_1 * restrict y = (block_q5_1 *)dst + b/QK5_1; + + quantize_row_q5_1_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int j = 0; j < QK5_1; j += 2) { + const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4; + const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12)); + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_1*sizeof(block_q5_1)); +} + +size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int b = 0; b < n; b += k) { + block_q8_0 * restrict y = (block_q8_0 *)dst + b/QK8_0; + + quantize_row_q8_0_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK8_0; ++j) { + const int8_t vi = y[i].qs[j]; + + hist[vi/16 + 8]++; + } + } + } + + return (n/QK8_0*sizeof(block_q8_0)); +} + +size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { + size_t result = 0; + switch (type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(start % QK4_0 == 0); + block_q4_0 * block = (block_q4_0*)dst + start / QK4_0; + result = ggml_quantize_q4_0(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(start % QK4_1 == 0); + block_q4_1 * block = (block_q4_1*)dst + start / QK4_1; + result = ggml_quantize_q4_1(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(start % QK5_0 == 0); + block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; + result = ggml_quantize_q5_0(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q5_1: + { + GGML_ASSERT(start % QK5_1 == 0); + block_q5_1 * block = (block_q5_1*)dst + start / QK5_1; + result = ggml_quantize_q5_1(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(start % QK8_0 == 0); + block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; + result = ggml_quantize_q8_0(src + start, block, n, n, hist); + } 
break; + case GGML_TYPE_Q2_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q2_K * block = (block_q2_K*)dst + start / QK_K; + result = ggml_quantize_q2_K(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q3_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q3_K * block = (block_q3_K*)dst + start / QK_K; + result = ggml_quantize_q3_K(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q4_K * block = (block_q4_K*)dst + start / QK_K; + result = ggml_quantize_q4_K(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q5_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q5_K * block = (block_q5_K*)dst + start / QK_K; + result = ggml_quantize_q5_K(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q6_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q6_K * block = (block_q6_K*)dst + start / QK_K; + result = ggml_quantize_q6_K(src + start, block, n, n, hist); + } break; + case GGML_TYPE_F16: + { + int elemsize = sizeof(ggml_fp16_t); + ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n); + result = n * elemsize; + } break; + case GGML_TYPE_F32: + { + int elemsize = sizeof(float); + result = n * elemsize; + memcpy((uint8_t *)dst + start * elemsize, src + start, result); + } break; + default: + assert(false); + } + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +struct gguf_str { + uint64_t n; // GGUFv2 + char * data; +}; + +static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = { + [GGUF_TYPE_UINT8] = sizeof(uint8_t), + [GGUF_TYPE_INT8] = sizeof(int8_t), + [GGUF_TYPE_UINT16] = sizeof(uint16_t), + [GGUF_TYPE_INT16] = sizeof(int16_t), + [GGUF_TYPE_UINT32] = sizeof(uint32_t), + [GGUF_TYPE_INT32] = sizeof(int32_t), + [GGUF_TYPE_FLOAT32] = sizeof(float), + [GGUF_TYPE_BOOL] = sizeof(bool), + [GGUF_TYPE_STRING] = sizeof(struct gguf_str), + [GGUF_TYPE_UINT64] = sizeof(uint64_t), + [GGUF_TYPE_INT64] = sizeof(int64_t), + [GGUF_TYPE_FLOAT64] = sizeof(double), + [GGUF_TYPE_ARRAY] = 0, // undefined +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = { + [GGUF_TYPE_UINT8] = "u8", + [GGUF_TYPE_INT8] = "i8", + [GGUF_TYPE_UINT16] = "u16", + [GGUF_TYPE_INT16] = "i16", + [GGUF_TYPE_UINT32] = "u32", + [GGUF_TYPE_INT32] = "i32", + [GGUF_TYPE_FLOAT32] = "f32", + [GGUF_TYPE_BOOL] = "bool", + [GGUF_TYPE_STRING] = "str", + [GGUF_TYPE_ARRAY] = "arr", + [GGUF_TYPE_UINT64] = "u64", + [GGUF_TYPE_INT64] = "i64", + [GGUF_TYPE_FLOAT64] = "f64", +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +union gguf_value { + uint8_t uint8; + int8_t int8; + uint16_t uint16; + int16_t int16; + uint32_t uint32; + int32_t int32; + float float32; + uint64_t uint64; + int64_t int64; + double float64; + bool bool_; + + struct gguf_str str; + + struct { + enum gguf_type type; + + uint64_t n; // GGUFv2 + void * data; + } arr; +}; + +struct gguf_kv { + struct gguf_str key; + + enum gguf_type type; + union gguf_value value; +}; + +struct gguf_header { + char magic[4]; + + uint32_t version; + uint64_t n_tensors; // GGUFv2 + uint64_t n_kv; // GGUFv2 +}; + +struct gguf_tensor_info { + struct gguf_str name; + + uint32_t n_dims; + uint64_t ne[GGML_MAX_DIMS]; + + enum ggml_type type; + + uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` + + // for writing API + const void * data; + size_t size; +}; + +struct gguf_context { + struct gguf_header header; + + struct gguf_kv * kv; + 
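// per-tensor metadata: name, shape, type and byte offset into the data section, one entry per tensor +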
struct gguf_tensor_info * infos; + + size_t alignment; + size_t offset; // offset of `data` from beginning of file + size_t size; // size of `data` in bytes + + //uint8_t * padding; + void * data; +}; + +static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) { + const size_t n = fread(dst, 1, size, file); + *offset += n; + return n == size; +} + +static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) { + p->n = 0; + p->data = NULL; + + bool ok = true; + + ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset); p->data = calloc(p->n + 1, 1); + ok = ok && gguf_fread_el(file, p->data, p->n, offset); + + return ok; +} + +struct gguf_context * gguf_init_empty(void) { + struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + + memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic)); + ctx->header.version = GGUF_VERSION; + ctx->header.n_tensors = 0; + ctx->header.n_kv = 0; + + ctx->kv = NULL; + ctx->infos = NULL; + + ctx->alignment = GGUF_DEFAULT_ALIGNMENT; + ctx->offset = 0; + ctx->size = 0; + + ctx->data = NULL; + + return ctx; +} + +struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { + FILE * file = fopen(fname, "rb"); + if (!file) { + return NULL; + } + + // offset from start of file + size_t offset = 0; + + char magic[4]; + + // check the magic before making allocations + { + gguf_fread_el(file, &magic, sizeof(magic), &offset); + + for (uint32_t i = 0; i < sizeof(magic); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]); + fclose(file); + return NULL; + } + } + } + + bool ok = true; + + struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + + // read the header + { + strncpy(ctx->header.magic, magic, 4); + + ctx->kv = NULL; + ctx->infos = NULL; + ctx->data = NULL; + + ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset); + ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); + ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); + + if (ctx->header.version == 1) { + fprintf(stderr, "%s: GGUFv1 is no longer supported. 
please use a more up-to-date version\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + + if (!ok) { + fprintf(stderr, "%s: failed to read header\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + } + + // read the kv pairs + { + ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv)); + + for (uint64_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv * kv = &ctx->kv[i]; + + //fprintf(stderr, "%s: reading kv %d\n", __func__, i); + + ok = ok && gguf_fread_str(file, &kv->key, &offset); + ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset); + + //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data); + + switch (kv->type) { + case GGUF_TYPE_UINT8: ok = ok && gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break; + case GGUF_TYPE_INT8: ok = ok && gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break; + case GGUF_TYPE_UINT16: ok = ok && gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break; + case GGUF_TYPE_INT16: ok = ok && gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break; + case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break; + case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break; + case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break; + case GGUF_TYPE_UINT64: ok = ok && gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break; + case GGUF_TYPE_INT64: ok = ok && gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break; + case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break; + case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break; + case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break; + case GGUF_TYPE_ARRAY: + { + ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset); + ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); + + switch (kv->value.arr.type) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: + case GGUF_TYPE_BOOL: + { + kv->value.arr.data = malloc(kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]); + ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type], &offset); + } break; + case GGUF_TYPE_STRING: + { + kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct gguf_str)); + for (uint64_t j = 0; j < kv->value.arr.n; ++j) { + ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset); + } + } break; + case GGUF_TYPE_ARRAY: + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; + } + } break; + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); + } + + if (!ok) { + break; + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + } + + // read the tensor infos + { + ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct gguf_tensor_info)); + + for 
(uint64_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + info->ne[j] = 1; + } + + ok = ok && gguf_fread_str(file, &info->name, &offset); + ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset); + for (uint32_t j = 0; j < info->n_dims; ++j) { + ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); + } + ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset); + ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor info\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + } + } + + ctx->alignment = GGUF_DEFAULT_ALIGNMENT; + + int alignment_idx = gguf_find_key(ctx, "general.alignment"); + if (alignment_idx != -1) { + ctx->alignment = gguf_get_val_u32(ctx, alignment_idx); + } + + // we require the data section to be aligned, so take into account any padding + { + const size_t offset_pad = offset % ctx->alignment; + + if (offset_pad != 0) { + offset += ctx->alignment - offset_pad; + fseek(file, offset, SEEK_SET); + } + } + + // store the current file offset - this is where the data section starts + ctx->offset = offset; + + // compute the total size of the data section, taking into account the alignment + { + ctx->size = 0; + for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + const int64_t ne = + (int64_t) info->ne[0] * + (int64_t) info->ne[1] * + (int64_t) info->ne[2] * + (int64_t) info->ne[3]; + + if (ne % ggml_blck_size(info->type) != 0) { + fprintf(stderr, "%s: tensor '%s' number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", + __func__, info->name.data, ne, ggml_blck_size(info->type)); + fclose(file); + gguf_free(ctx); + return NULL; + } + + const size_t size_cur = ggml_row_size(info->type, ne); + + ctx->size += GGML_PAD(size_cur, ctx->alignment); + } + } + + // load the tensor data only if requested + if (params.ctx != NULL) { + // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob + // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of + // the ggml_tensor structs to the appropriate locations in the binary blob + + // compute the exact size needed for the new ggml_context + const size_t mem_size = + params.no_alloc ? 
+ (ctx->header.n_tensors )*ggml_tensor_overhead() : + (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size; + + struct ggml_init_params pdata = { + .mem_size = mem_size, + .mem_buffer = NULL, + .no_alloc = params.no_alloc, + }; + + *params.ctx = ggml_init(pdata); + + struct ggml_context * ctx_data = *params.ctx; + + struct ggml_tensor * data = NULL; + + if (!params.no_alloc) { + data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); + + ok = ok && data != NULL; + + // read the binary blob with the tensor data + ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset); + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor data\n", __func__); + fclose(file); + ggml_free(ctx_data); + gguf_free(ctx); + return NULL; + } + + ctx->data = data->data; + } + + ggml_set_no_alloc(ctx_data, true); + + // create the tensors + for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { + const int64_t ne[GGML_MAX_DIMS] = { + ctx->infos[i].ne[0], + ctx->infos[i].ne[1], + ctx->infos[i].ne[2], + ctx->infos[i].ne[3], + }; + + struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne); + + ok = ok && cur != NULL; + + ggml_set_name(cur, ctx->infos[i].name.data); + + if (!ok) { + break; + } + + // point the data member to the appropriate location in the binary blob using the tensor infos + if (!params.no_alloc) { + //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file + cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to read the tensor data\n", __func__); + fclose(file); + ggml_free(ctx_data); + gguf_free(ctx); + return NULL; + } + + ggml_set_no_alloc(ctx_data, params.no_alloc); + } + + fclose(file); + + return ctx; +} + +void gguf_free(struct gguf_context * ctx) { + if (ctx == NULL) { + return; + } + + if (ctx->kv) { + // free string memory - not great.. 
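+        // keys, string values and array buffers are heap-allocated (calloc/strdup/malloc) when a file is
+        // read or when KV pairs are set, so each one has to be released individually before freeing ctx->kv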
+ for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv * kv = &ctx->kv[i]; + + if (kv->key.data) { + free(kv->key.data); + } + + if (kv->type == GGUF_TYPE_STRING) { + if (kv->value.str.data) { + free(kv->value.str.data); + } + } + + if (kv->type == GGUF_TYPE_ARRAY) { + if (kv->value.arr.data) { + if (kv->value.arr.type == GGUF_TYPE_STRING) { + for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j]; + if (str->data) { + free(str->data); + } + } + } + free(kv->value.arr.data); + } + } + } + + free(ctx->kv); + } + + if (ctx->infos) { + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + if (info->name.data) { + free(info->name.data); + } + } + + free(ctx->infos); + } + + GGML_ALIGNED_FREE(ctx); +} + +const char * gguf_type_name(enum gguf_type type) { + return GGUF_TYPE_NAME[type]; +} + +int gguf_get_version(const struct gguf_context * ctx) { + return ctx->header.version; +} + +size_t gguf_get_alignment(const struct gguf_context * ctx) { + return ctx->alignment; +} + +size_t gguf_get_data_offset(const struct gguf_context * ctx) { + return ctx->offset; +} + +void * gguf_get_data(const struct gguf_context * ctx) { + return ctx->data; +} + +int gguf_get_n_kv(const struct gguf_context * ctx) { + return ctx->header.n_kv; +} + +int gguf_find_key(const struct gguf_context * ctx, const char * key) { + // return -1 if key not found + int keyfound = -1; + + const int n_kv = gguf_get_n_kv(ctx); + + for (int i = 0; i < n_kv; ++i) { + if (strcmp(key, gguf_get_key(ctx, i)) == 0) { + keyfound = i; + break; + } + } + + return keyfound; +} + +const char * gguf_get_key(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].key.data; +} + +enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].type; +} + +enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.type; +} + +const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.data; +} + +const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + struct gguf_kv * kv = &ctx->kv[key_id]; + struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; + return str->data; +} + +int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.n; +} + +uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8); + return ctx->kv[key_id].value.uint8; +} + +int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8); + return ctx->kv[key_id].value.int8; +} + +uint16_t 
gguf_get_val_u16(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16); + return ctx->kv[key_id].value.uint16; +} + +int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16); + return ctx->kv[key_id].value.int16; +} + +uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32); + return ctx->kv[key_id].value.uint32; +} + +int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32); + return ctx->kv[key_id].value.int32; +} + +float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32); + return ctx->kv[key_id].value.float32; +} + +uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64); + return ctx->kv[key_id].value.uint64; +} + +int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64); + return ctx->kv[key_id].value.int64; +} + +double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64); + return ctx->kv[key_id].value.float64; +} + +bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL); + return ctx->kv[key_id].value.bool_; +} + +const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING); + return ctx->kv[key_id].value.str.data; +} + +const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY); + GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING); + return &ctx->kv[key_id].value; +} + +int gguf_get_n_tensors(const struct gguf_context * ctx) { + return ctx->header.n_tensors; +} + +int gguf_find_tensor(const struct gguf_context * ctx, const char * name) { + // return -1 if tensor not found + int tensorfound = -1; + + const int n_tensors = gguf_get_n_tensors(ctx); + + for (int i = 0; i < n_tensors; ++i) { + if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { + tensorfound = i; + break; + } + } + + return tensorfound; +} + +size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) { + return ctx->infos[i].offset; +} + +char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) { + return ctx->infos[i].name.data; +} + +// returns the index +static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) { + const int idx = gguf_find_key(ctx, key); + if (idx >= 0) { + return idx; + } + + const int n_kv = gguf_get_n_kv(ctx); + + ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct 
gguf_kv)); + ctx->kv[n_kv].key.n = strlen(key); + ctx->kv[n_kv].key.data = strdup(key); + ctx->header.n_kv++; + + return n_kv; +} + +void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT8; + ctx->kv[idx].value.uint8 = val; +} + +void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT8; + ctx->kv[idx].value.int8 = val; +} + +void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT16; + ctx->kv[idx].value.uint16 = val; +} + +void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT16; + ctx->kv[idx].value.int16 = val; +} + +void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT32; + ctx->kv[idx].value.uint32 = val; +} + +void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT32; + ctx->kv[idx].value.int32 = val; +} + +void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_FLOAT32; + ctx->kv[idx].value.float32 = val; +} + +void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT64; + ctx->kv[idx].value.uint64 = val; +} + +void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT64; + ctx->kv[idx].value.int64 = val; +} + +void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_FLOAT64; + ctx->kv[idx].value.float64 = val; +} + +void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_BOOL; + ctx->kv[idx].value.bool_ = val; +} + +void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_STRING; + ctx->kv[idx].value.str.n = strlen(val); + ctx->kv[idx].value.str.data = strdup(val); +} + +void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_ARRAY; + ctx->kv[idx].value.arr.type = type; + ctx->kv[idx].value.arr.n = n; + ctx->kv[idx].value.arr.data = malloc(n*GGUF_TYPE_SIZE[type]); + memcpy(ctx->kv[idx].value.arr.data, data, n*GGUF_TYPE_SIZE[type]); +} + +void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_ARRAY; + ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING; + ctx->kv[idx].value.arr.n = n; + ctx->kv[idx].value.arr.data = malloc(n*sizeof(struct gguf_str)); + for (int i = 0; i < n; i++) { + struct gguf_str * str = 
&((struct gguf_str *)ctx->kv[idx].value.arr.data)[i]; + str->n = strlen(data[i]); + str->data = strdup(data[i]); + } +} + +// set or add KV pairs from another context +void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { + for (uint32_t i = 0; i < src->header.n_kv; i++) { + switch (src->kv[i].type) { + case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break; + case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break; + case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break; + case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break; + case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break; + case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break; + case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break; + case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break; + case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break; + case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break; + case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break; + case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break; + case GGUF_TYPE_ARRAY: + { + if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) { + const char ** data = malloc(src->kv[i].value.arr.n*sizeof(char *)); + for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) { + data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data; + } + gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n); + free(data); + } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) { + GGML_ASSERT(false && "nested arrays not supported"); + } else { + gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n); + } + } break; + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; + } + } +} + +void gguf_add_tensor( + struct gguf_context * ctx, + const struct ggml_tensor * tensor) { + const int idx = ctx->header.n_tensors; + ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info)); + + ctx->infos[idx].name.n = strlen(tensor->name); + ctx->infos[idx].name.data = strdup(tensor->name); + + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + ctx->infos[idx].ne[i] = 1; + } + + ctx->infos[idx].n_dims = ggml_n_dims(tensor); + for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) { + ctx->infos[idx].ne[i] = tensor->ne[i]; + } + + ctx->infos[idx].type = tensor->type; + ctx->infos[idx].offset = 0; + ctx->infos[idx].data = tensor->data; + ctx->infos[idx].size = ggml_nbytes(tensor); + + if (ctx->header.n_tensors > 0) { + ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment); + } + + ctx->header.n_tensors++; +} + +void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { + const int idx = gguf_find_tensor(ctx, name); + if (idx < 0) { + GGML_ASSERT(false && "tensor not found"); + } + + ctx->infos[idx].type = type; +} + +void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) { + const int idx = gguf_find_tensor(ctx, name); + if (idx < 
0) { + GGML_ASSERT(false && "tensor not found"); + } + + ctx->infos[idx].data = data; + ctx->infos[idx].size = size; + + // update offsets + for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) { + ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment); + } +} + +//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) { +// fwrite(&val->n, sizeof(val->n), 1, file); +// fwrite(val->data, sizeof(char), val->n, file); +//} +// +//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) { +// fwrite(val, sizeof(char), size, file); +//} + +struct gguf_buf { + void * data; + size_t size; + size_t offset; +}; + +static struct gguf_buf gguf_buf_init(size_t size) { + struct gguf_buf buf = { + /*buf.data =*/ size == 0 ? NULL : malloc(size), + /*buf.size =*/ size, + /*buf.offset =*/ 0, + }; + + return buf; +} + +static void gguf_buf_free(struct gguf_buf buf) { + if (buf.data) { + free(buf.data); + } +} + +static void gguf_buf_grow(struct gguf_buf * buf, size_t size) { + if (buf->offset + size > buf->size) { + buf->size = 1.5*(buf->offset + size); + if (buf->data) { + buf->data = realloc(buf->data, buf->size); + } + } +} + +static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) { + gguf_buf_grow(buf, sizeof(val->n) + val->n); + + if (buf->data) { + memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n)); + } + buf->offset += sizeof(val->n); + + if (buf->data) { + memcpy((char *) buf->data + buf->offset, val->data, val->n); + } + buf->offset += val->n; +} + +static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) { + gguf_buf_grow(buf, el_size); + + if (buf->data) { + memcpy((char *) buf->data + buf->offset, val, el_size); + } + buf->offset += el_size; +} + +static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) { + // write header + gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic)); + gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version)); + gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors)); + gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv)); + + // write key-value pairs + for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv * kv = &ctx->kv[i]; + + gguf_bwrite_str(buf, &kv->key); + gguf_bwrite_el (buf, &kv->type, sizeof(kv->type)); + + switch (kv->type) { + case GGUF_TYPE_UINT8: gguf_bwrite_el( buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break; + case GGUF_TYPE_INT8: gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break; + case GGUF_TYPE_UINT16: gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break; + case GGUF_TYPE_INT16: gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break; + case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break; + case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break; + case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break; + case GGUF_TYPE_UINT64: gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break; + case GGUF_TYPE_INT64: gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break; + case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break; + case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break; + case 
GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break; + case GGUF_TYPE_ARRAY: + { + gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type)); + gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) ); + + switch (kv->value.arr.type) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: + case GGUF_TYPE_BOOL: + { + gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]); + } break; + case GGUF_TYPE_STRING: + { + for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]); + } + } break; + case GGUF_TYPE_ARRAY: + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; + } + } break; + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); + } + } + + // write tensor infos + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + gguf_bwrite_str(buf, &info->name); + gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims)); + for (uint32_t j = 0; j < info->n_dims; ++j) { + gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j])); + } + gguf_bwrite_el(buf, &info->type, sizeof(info->type)); + gguf_bwrite_el(buf, &info->offset, sizeof(info->offset)); + } + + // we require the data section to be aligned, so take into account any padding + { + const size_t offset = buf->offset; + const size_t offset_pad = GGML_PAD(offset, ctx->alignment); + + if (offset_pad != offset) { + uint8_t pad = 0; + for (size_t i = 0; i < offset_pad - offset; ++i) { + gguf_bwrite_el(buf, &pad, sizeof(pad)); + } + } + } + + if (only_meta) { + return; + } + + size_t offset = 0; + + // write tensor data + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + const size_t size = info->size; + const size_t size_pad = GGML_PAD(size, ctx->alignment); + + gguf_bwrite_el(buf, info->data, size); + + if (size_pad != size) { + uint8_t pad = 0; + for (size_t j = 0; j < size_pad - size; ++j) { + gguf_bwrite_el(buf, &pad, sizeof(pad)); + } + } + + GGML_ASSERT(offset == info->offset); + + offset += size_pad; + } +} + +void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { + FILE * file = fopen(fname, "wb"); + if (!file) { + GGML_ASSERT(false && "failed to open file for writing"); + } + + struct gguf_buf buf = gguf_buf_init(16*1024); + + gguf_write_to_buf(ctx, &buf, only_meta); + + fwrite(buf.data, 1, buf.offset, file); + + gguf_buf_free(buf); + + fclose(file); +} + +size_t gguf_get_meta_size(const struct gguf_context * ctx) { + // no allocs - only compute size + struct gguf_buf buf = gguf_buf_init(0); + + gguf_write_to_buf(ctx, &buf, true); + + return buf.offset; +} + +void gguf_get_meta_data(const struct gguf_context * ctx, void * data) { + struct gguf_buf buf = gguf_buf_init(16*1024); + + gguf_write_to_buf(ctx, &buf, true); + + memcpy(data, buf.data, buf.offset); + + gguf_buf_free(buf); +} + +//////////////////////////////////////////////////////////////////////////////// + +int ggml_cpu_has_avx(void) { +#if defined(__AVX__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx2(void) { +#if defined(__AVX2__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512(void) { +#if defined(__AVX512F__) + return 1; +#else + return 0; +#endif 
+} + +int ggml_cpu_has_avx512_vbmi(void) { +#if defined(__AVX512VBMI__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512_vnni(void) { +#if defined(__AVX512VNNI__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fma(void) { +#if defined(__FMA__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_neon(void) { +#if defined(__ARM_NEON) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_arm_fma(void) { +#if defined(__ARM_FEATURE_FMA) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_metal(void) { +#if defined(GGML_USE_METAL) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_f16c(void) { +#if defined(__F16C__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fp16_va(void) { +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_wasm_simd(void) { +#if defined(__wasm_simd128__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_blas(void) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_cublas(void) { +#if defined(GGML_USE_CUBLAS) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_clblast(void) { +#if defined(GGML_USE_CLBLAST) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_gpublas(void) { + return ggml_cpu_has_cublas() || ggml_cpu_has_clblast(); +} + +int ggml_cpu_has_sse3(void) { +#if defined(__SSE3__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_ssse3(void) { +#if defined(__SSSE3__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_vsx(void) { +#if defined(__POWER9_VECTOR__) + return 1; +#else + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/applications/src/llama/ggml.h b/applications/src/llama/ggml.h new file mode 100644 index 000000000..68f7833b6 --- /dev/null +++ b/applications/src/llama/ggml.h @@ -0,0 +1,2234 @@ +#pragma once + +// +// GGML Tensor Library +// +// This documentation is still a work in progress. +// If you wish some specific topics to be covered, feel free to drop a comment: +// +// https://github.com/ggerganov/whisper.cpp/issues/40 +// +// ## Overview +// +// This library implements: +// +// - a set of tensor operations +// - automatic differentiation +// - basic optimization algorithms +// +// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, +// but is not limited to, the following: +// +// - linear regression +// - support vector machines +// - neural networks +// +// The library allows the user to define a certain function using the available tensor operations. This function +// definition is represented internally via a computation graph. Each tensor operation in the function definition +// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the +// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized +// using one of the available optimization algorithms. 
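//
// The optimization path goes through ggml_opt(), declared near the end of this
// header. A minimal sketch, assuming a context ctx, a function f and an input
// variable marked with ggml_set_param() exactly as in the example that follows:
//
// {
//     // start from the ADAM defaults (see ggml_opt_default_params() in ggml.c)
//     struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
//
//     // iteratively adjusts every tensor marked with ggml_set_param() to minimize f;
//     // returns GGML_OPT_OK on success (other ggml_opt_result values signal
//     // non-convergence or failure)
//     enum ggml_opt_result res = ggml_opt(ctx, opt_params, f);
//     GGML_UNUSED(res);
// }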
+// +// For example, here we define the function: f(x) = a*x^2 + b +// +// { +// struct ggml_init_params params = { +// .mem_size = 16*1024*1024, +// .mem_buffer = NULL, +// }; +// +// // memory allocation happens here +// struct ggml_context * ctx = ggml_init(params); +// +// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// +// ggml_set_param(ctx, x); // x is an input variable +// +// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); +// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); +// +// ... +// } +// +// Notice that the function definition above does not involve any actual computation. The computation is performed only +// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: +// +// { +// ... +// +// struct ggml_cgraph * gf = ggml_new_graph(ctx); +// ggml_build_forward_expand(gf, f); +// +// // set the input variable and parameter values +// ggml_set_f32(x, 2.0f); +// ggml_set_f32(a, 3.0f); +// ggml_set_f32(b, 4.0f); +// +// ggml_graph_compute_with_ctx(ctx, &gf, n_threads); +// +// printf("f = %f\n", ggml_get_f32_1d(f, 0)); +// +// ... +// } +// +// The actual computation is performed in the ggml_graph_compute() function. +// +// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the +// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know +// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory +// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was +// actually needed. +// +// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic +// differentiation and optimization algorithms. +// +// The described approach allows to define the function graph once and then compute its forward or backward graphs +// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way +// the user can avoid the memory allocation overhead at runtime. +// +// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class +// citizens, but in theory the library can be extended to support FP8 and integer data types. +// +// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary +// and binary operations. Most of the available operations fall into one of these two categories. With time, it became +// clear that the library needs to support more complex operations. The way to support these operations is not clear +// yet, but a few examples are demonstrated in the following operations: +// +// - ggml_permute() +// - ggml_conv_1d_1s() +// - ggml_conv_1d_2s() +// +// For each tensor operator, the library implements a forward and backward computation function. The forward function +// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the +// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a +// calculus class, or watch the following video: +// +// What is Automatic Differentiation? 
+// https://www.youtube.com/watch?v=wG_nF1awSSY +// +// +// ## Tensor data (struct ggml_tensor) +// +// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of +// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains +// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: +// +// { +// struct ggml_tensor * c = ggml_add(ctx, a, b); +// +// assert(c->src[0] == a); +// assert(c->src[1] == b); +// } +// +// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the +// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows +// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and +// permutation. All tensor operations have to take the stride into account and not assume that the tensor is +// contiguous in memory. +// +// The data of the tensor is accessed via the "data" pointer. For example: +// +// { +// const int nx = 2; +// const int ny = 3; +// +// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny); +// +// for (int y = 0; y < ny; y++) { +// for (int x = 0; x < nx; x++) { +// *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y; +// } +// } +// +// ... +// } +// +// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. +// +// ## The matrix multiplication operator (ggml_mul_mat) +// +// TODO +// +// +// ## Multi-threading +// +// TODO +// +// +// ## Overview of ggml.c +// +// TODO +// +// +// ## SIMD optimizations +// +// TODO +// +// +// ## Debugging ggml +// +// TODO +// +// + +#ifdef GGML_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BUILD +# define GGML_API __declspec(dllexport) +# else +# define GGML_API __declspec(dllimport) +# endif +# else +# define GGML_API __attribute__ ((visibility ("default"))) +# endif +#else +# define GGML_API +#endif + +// TODO: support for clang +#ifdef __GNUC__ +# define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define GGML_DEPRECATED(func, hint) func +#endif + +#ifndef __GNUC__ +# define GGML_ATTRIBUTE_FORMAT(...) +#elif defined(__MINGW32__) +# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +# define GGML_ATTRIBUTE_FORMAT(...) 
__attribute__((format(printf, __VA_ARGS__))) +#endif + +#include +#include +#include + +#define GGML_FILE_MAGIC 0x67676d6c // "ggml" +#define GGML_FILE_VERSION 1 + +#define GGML_QNT_VERSION 2 // bump this on quantization format changes +#define GGML_QNT_VERSION_FACTOR 1000 // do not change this + +#define GGML_MAX_DIMS 4 +#define GGML_MAX_PARAMS 2048 +#define GGML_MAX_CONTEXTS 64 +#define GGML_MAX_SRC 10 +#define GGML_MAX_NAME 64 +#define GGML_MAX_OP_PARAMS 64 +#define GGML_DEFAULT_N_THREADS 4 +#define GGML_DEFAULT_GRAPH_SIZE 2048 +#if UINTPTR_MAX == 0xFFFFFFFF + #define GGML_MEM_ALIGN 4 +#else + #define GGML_MEM_ALIGN 16 +#endif + +#define GGML_EXIT_SUCCESS 0 +#define GGML_EXIT_ABORTED 1 + +#define GGUF_MAGIC "GGUF" + +#define GGUF_VERSION 3 + +#define GGUF_DEFAULT_ALIGNMENT 32 + +#define GGML_UNUSED(x) (void)(x) + +#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#define GGML_ASSERT(x) \ + do { \ + if (!(x)) { \ + fflush(stdout); \ + fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + ggml_print_backtrace(); \ + abort(); \ + } \ + } while (0) + +#ifndef NDEBUG +#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached") +#elif defined(__GNUC__) +#define GGML_UNREACHABLE() __builtin_unreachable() +#else +#define GGML_UNREACHABLE() ((void) 0) +#endif + +// used to copy the number of elements and stride in bytes of tensors into local variables. +// main purpose is to reduce code duplication and improve readability. +// +// example: +// +// GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); +// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); +// +#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ + const type prefix##0 = (pointer)->array[0]; \ + GGML_UNUSED(prefix##0); +#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ + const type prefix##1 = (pointer)->array[1]; \ + GGML_UNUSED(prefix##1); +#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ + const type prefix##2 = (pointer)->array[2]; \ + GGML_UNUSED(prefix##2); +#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ + const type prefix##3 = (pointer)->array[3]; \ + GGML_UNUSED(prefix##3); + +#define GGML_TENSOR_UNARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + +#define GGML_TENSOR_BINARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(__ARM_NEON) && defined(__CUDACC__) + typedef half ggml_fp16_t; +#elif defined(__ARM_NEON) + typedef __fp16 ggml_fp16_t; +#else + typedef uint16_t ggml_fp16_t; +#endif + + // convert FP16 <-> FP32 + GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); + GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); + + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n); + GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n); + + struct ggml_object; + struct ggml_context; + + enum ggml_type { + GGML_TYPE_F32 = 0, + GGML_TYPE_F16 = 1, + GGML_TYPE_Q4_0 = 2, + GGML_TYPE_Q4_1 = 3, + // GGML_TYPE_Q4_2 = 4, support 
has been removed + // GGML_TYPE_Q4_3 (5) support has been removed + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, + // k-quantizations + GGML_TYPE_Q2_K = 10, + GGML_TYPE_Q3_K = 11, + GGML_TYPE_Q4_K = 12, + GGML_TYPE_Q5_K = 13, + GGML_TYPE_Q6_K = 14, + GGML_TYPE_Q8_K = 15, + GGML_TYPE_I8, + GGML_TYPE_I16, + GGML_TYPE_I32, + GGML_TYPE_COUNT, + }; + + enum ggml_backend_type { + GGML_BACKEND_CPU = 0, + GGML_BACKEND_GPU = 10, + GGML_BACKEND_GPU_SPLIT = 20, + }; + + // model file types + enum ggml_ftype { + GGML_FTYPE_UNKNOWN = -1, + GGML_FTYPE_ALL_F32 = 0, + GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors + GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors + }; + + // available tensor operations: + enum ggml_op { + GGML_OP_NONE = 0, + + GGML_OP_DUP, + GGML_OP_ADD, + GGML_OP_ADD1, + GGML_OP_ACC, + GGML_OP_SUB, + GGML_OP_MUL, + GGML_OP_DIV, + GGML_OP_SQR, + GGML_OP_SQRT, + GGML_OP_LOG, + GGML_OP_SUM, + GGML_OP_SUM_ROWS, + GGML_OP_MEAN, + GGML_OP_ARGMAX, + GGML_OP_REPEAT, + GGML_OP_REPEAT_BACK, + GGML_OP_CONCAT, + GGML_OP_SILU_BACK, + GGML_OP_NORM, // normalize + GGML_OP_RMS_NORM, + GGML_OP_RMS_NORM_BACK, + GGML_OP_GROUP_NORM, + + GGML_OP_MUL_MAT, + GGML_OP_MUL_MAT_ID, + GGML_OP_OUT_PROD, + + GGML_OP_SCALE, + GGML_OP_SET, + GGML_OP_CPY, + GGML_OP_CONT, + GGML_OP_RESHAPE, + GGML_OP_VIEW, + GGML_OP_PERMUTE, + GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, + GGML_OP_GET_ROWS_BACK, + GGML_OP_DIAG, + GGML_OP_DIAG_MASK_INF, + GGML_OP_DIAG_MASK_ZERO, + GGML_OP_SOFT_MAX, + GGML_OP_SOFT_MAX_BACK, + GGML_OP_ROPE, + GGML_OP_ROPE_BACK, + GGML_OP_ALIBI, + GGML_OP_CLAMP, + GGML_OP_CONV_TRANSPOSE_1D, + GGML_OP_IM2COL, + GGML_OP_CONV_TRANSPOSE_2D, + GGML_OP_POOL_1D, + GGML_OP_POOL_2D, + GGML_OP_UPSCALE, // nearest interpolate + GGML_OP_PAD, + GGML_OP_ARGSORT, + GGML_OP_LEAKY_RELU, + + GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_FF, + GGML_OP_FLASH_ATTN_BACK, + GGML_OP_WIN_PART, + GGML_OP_WIN_UNPART, + GGML_OP_GET_REL_POS, + GGML_OP_ADD_REL_POS, + + GGML_OP_UNARY, + + GGML_OP_MAP_UNARY, + GGML_OP_MAP_BINARY, + + GGML_OP_MAP_CUSTOM1_F32, + GGML_OP_MAP_CUSTOM2_F32, + GGML_OP_MAP_CUSTOM3_F32, + + GGML_OP_MAP_CUSTOM1, + GGML_OP_MAP_CUSTOM2, + GGML_OP_MAP_CUSTOM3, + + GGML_OP_CROSS_ENTROPY_LOSS, + GGML_OP_CROSS_ENTROPY_LOSS_BACK, + + GGML_OP_COUNT, + }; + + enum ggml_unary_op { + GGML_UNARY_OP_ABS, + GGML_UNARY_OP_SGN, + GGML_UNARY_OP_NEG, + GGML_UNARY_OP_STEP, + GGML_UNARY_OP_TANH, + GGML_UNARY_OP_ELU, + GGML_UNARY_OP_RELU, + GGML_UNARY_OP_GELU, + GGML_UNARY_OP_GELU_QUICK, + GGML_UNARY_OP_SILU, + + GGML_UNARY_OP_COUNT, + }; + + enum ggml_object_type { + GGML_OBJECT_TENSOR, + GGML_OBJECT_GRAPH, + GGML_OBJECT_WORK_BUFFER + }; + + enum ggml_log_level { + GGML_LOG_LEVEL_ERROR = 2, + GGML_LOG_LEVEL_WARN = 3, + GGML_LOG_LEVEL_INFO = 4 + }; + + // ggml object + struct ggml_object { + size_t offs; + size_t size; + + struct ggml_object * next; + + enum ggml_object_type type; + + char padding[4]; + }; + + static const size_t GGML_OBJECT_SIZE = sizeof(struct 
ggml_object); + + // n-dimensional tensor + struct ggml_tensor { + enum ggml_type type; + enum ggml_backend_type backend; + + struct ggml_backend_buffer * buffer; + + int64_t ne[GGML_MAX_DIMS]; // number of elements + size_t nb[GGML_MAX_DIMS]; // stride in bytes: + // nb[0] = ggml_type_size(type) + // nb[1] = nb[0] * (ne[0] / ggml_blck_size(type)) + padding + // nb[i] = nb[i-1] * ne[i-1] + + // compute data + enum ggml_op op; + + // op params - allocated as int32_t for alignment + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + + bool is_param; + + struct ggml_tensor * grad; + struct ggml_tensor * src[GGML_MAX_SRC]; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + + struct ggml_tensor * view_src; + size_t view_offs; + + void * data; + + char name[GGML_MAX_NAME]; + + void * extra; // extra things e.g. for ggml-cuda.cu + + char padding[8]; + }; + + static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); + + // the compute plan that needs to be prepared for ggml_graph_compute() + // since https://github.com/ggerganov/ggml/issues/287 + struct ggml_cplan { + size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` + uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` + + int n_threads; + + // abort ggml_graph_compute when true + bool (*abort_callback)(void * data); + void * abort_callback_data; + }; + + enum ggml_cgraph_eval_order { + GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0, + GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT, + GGML_CGRAPH_EVAL_ORDER_COUNT + }; + + struct ggml_hash_set { + size_t size; + struct ggml_tensor ** keys; + }; + + // computation graph + struct ggml_cgraph { + int size; + int n_nodes; + int n_leafs; + + struct ggml_tensor ** nodes; + struct ggml_tensor ** grads; + struct ggml_tensor ** leafs; + + struct ggml_hash_set visited_hash_table; + + enum ggml_cgraph_eval_order order; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + }; + + // scratch buffer + struct ggml_scratch { + size_t offs; + size_t size; + void * data; + }; + + struct ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally + bool no_alloc; // don't allocate memory for the tensor data + }; + + + // compute types + + // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. + // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. 
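// A minimal sketch of the plan-based compute API declared further below
// (ggml_graph_plan() / ggml_graph_compute()), where the caller allocates the
// work buffer itself rather than relying on ggml_graph_compute_with_ctx();
// it assumes a graph gf built as in the overview example at the top of this
// header:
//
// {
//     struct ggml_cplan plan = ggml_graph_plan(gf, GGML_DEFAULT_N_THREADS);
//
//     // when plan.work_size > 0 the caller must provide the work buffer
//     uint8_t * work_data = NULL;
//     if (plan.work_size > 0) {
//         work_data = malloc(plan.work_size);
//         plan.work_data = work_data;
//     }
//
//     ggml_graph_compute(gf, &plan);
//
//     free(work_data);
// }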
+ enum ggml_task_type { + GGML_TASK_INIT = 0, + GGML_TASK_COMPUTE, + GGML_TASK_FINALIZE, + }; + + struct ggml_compute_params { + enum ggml_task_type type; + + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + }; + + // misc + + GGML_API void ggml_time_init(void); // call this once at the beginning of the program + GGML_API int64_t ggml_time_ms(void); + GGML_API int64_t ggml_time_us(void); + GGML_API int64_t ggml_cycles(void); + GGML_API int64_t ggml_cycles_per_ms(void); + + GGML_API void ggml_print_backtrace(void); + + GGML_API void ggml_numa_init(void); // call once for better performance on NUMA systems + GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node + + GGML_API void ggml_print_object (const struct ggml_object * obj); + GGML_API void ggml_print_objects(const struct ggml_context * ctx); + + GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor); + GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN + + GGML_API int ggml_blck_size(enum ggml_type type); + GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block + GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row + + GGML_DEPRECATED( + GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float + "use ggml_row_size() instead"); + + GGML_API const char * ggml_type_name(enum ggml_type type); + GGML_API const char * ggml_op_name (enum ggml_op op); + GGML_API const char * ggml_op_symbol(enum ggml_op op); + + GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op); + GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name + + GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); + + GGML_API bool ggml_is_quantized(enum ggml_type type); + + // TODO: temporary until model loading of ggml examples is refactored + GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); + + GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); + GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars + + GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1); + + // use this to compute the memory overhead of a tensor + GGML_API size_t ggml_tensor_overhead(void); + + // main + + GGML_API struct ggml_context * ggml_init(struct ggml_init_params params); + GGML_API void ggml_free(struct ggml_context * ctx); + + GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); + + GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch); + GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx); + GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); + + GGML_API void * ggml_get_mem_buffer 
(const struct ggml_context * ctx); + GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx); + GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx); + + GGML_API struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t *ne); + + GGML_API struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1); + + GGML_API struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); + GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); + + GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); + GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); + + // Context tensor enumeration and lookup + GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx); + GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); + + GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); + GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); + + // Converts a flat index into coordinates + GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); + + GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); + GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + + GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); + + GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); + GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); + + GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); + + GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); + GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); + + GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); + + GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name); + GGML_ATTRIBUTE_FORMAT(2, 3) + GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...); + + // + // operations on tensors with backpropagation + // + + GGML_API struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // in-place, 
returns view(a) + GGML_API struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type); + + GGML_API struct ggml_tensor * ggml_add1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // dst = a + // view(dst, nb1, nb2, nb3, offset) += b + // return dst + GGML_API struct ggml_tensor * ggml_acc( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_acc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_log( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_log_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // return scalar + GGML_API struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d] + GGML_API struct ggml_tensor * ggml_sum_rows( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // mean along rows + GGML_API struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // argmax along rows + GGML_API struct ggml_tensor * ggml_argmax( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // if a is the same shape as b, and a is not parameter, return a + // otherwise, return a new tensor: repeat(a) to fit in b + GGML_API struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // sums repetitions in a into shape of b + GGML_API struct ggml_tensor * ggml_repeat_back( + 
struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // concat a and b on dim 2 + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_concat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_tanh( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_tanh_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_leaky_relu( + struct ggml_context * ctx, + struct ggml_tensor * a, float negative_slope, bool inplace); + + GGML_API struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_silu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // a - x + // b - dy + GGML_API struct ggml_tensor * ggml_silu_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // normalize along rows + GGML_API struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + // group normalize along ne0*ne1*n_groups + // used in stable-diffusion + // TODO: eps is hardcoded to 1e-6 for now + GGML_API struct ggml_tensor * ggml_group_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups); + + GGML_API struct ggml_tensor * ggml_group_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups); + + // 
a - x + // b - dy + GGML_API struct ggml_tensor * ggml_rms_norm_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float eps); + + // A: k columns, n rows => [ne03, ne02, n, k] + // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k] + // result is n columns, m rows => [ne03 * x, ne02 * y, m, n] + GGML_API struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // indirect matrix multiplication + // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b) + GGML_API struct ggml_tensor * ggml_mul_mat_id( + struct ggml_context * ctx, + struct ggml_tensor * const as[], + int n_as, + struct ggml_tensor * ids, + int id, + struct ggml_tensor * b); + + // A: m columns, n rows, + // B: p columns, n rows, + // result is m columns, p rows + GGML_API struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // + // operations on tensors without backpropagation + // + + GGML_API struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // b -> view(a,offset,nb1,nb2,3), return modified a + GGML_API struct ggml_tensor * ggml_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return view(a) + GGML_API struct ggml_tensor * ggml_set_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_set_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset); + + GGML_API struct ggml_tensor * ggml_set_1d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return modified a + GGML_API struct ggml_tensor * ggml_set_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return view(a) + GGML_API struct ggml_tensor * ggml_set_2d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset); + + // a -> b, return view(b) + GGML_API struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // a -> b, in-place, return view(b) + GGML_API struct ggml_tensor * ggml_cpy_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // make contiguous + GGML_API struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // make contiguous, in-place + GGML_API struct ggml_tensor * ggml_cont_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // make contiguous, with new shape + GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct 
ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // return view(a), b specifies the new shape + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_reshape_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // offset in bytes + GGML_API struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, // row stride in bytes + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3); + + // alias for ggml_permute(ctx, a, 1, 0, 2, 3) + GGML_API struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // supports 3D: a->ne[2] == b->ne[1] + GGML_API struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_get_rows_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c); + + GGML_API struct ggml_tensor * ggml_diag( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // set elements above the diagonal to -INF + GGML_API struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // set elements above the diagonal to 0 + GGML_API struct ggml_tensor * ggml_diag_mask_zero( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + GGML_API struct ggml_tensor * ggml_soft_max( 
+ struct ggml_context * ctx, + struct ggml_tensor * a); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_soft_max_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // fused soft_max(a*scale + mask) + // mask is optional + GGML_API struct ggml_tensor * ggml_soft_max_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale); + + GGML_API struct ggml_tensor * ggml_soft_max_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // rotary position embedding + // if mode & 1 == 1, skip n_past elements (DEPRECATED) + // if mode & 2 == 1, GPT-NeoX style + // if mode & 4 == 1, ChatGLM style + // + // b is an int32 vector with size a->ne[2], it contains the positions + GGML_API struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx); + + // custom RoPE + GGML_API struct ggml_tensor * ggml_rope_custom( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // compute correction dims for YaRN RoPE scaling + void ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); + + // xPos RoPE, in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + float base, + bool down); + + // rotary position embedding backward, i.e compute dx from dy + // a - dy + GGML_API struct ggml_tensor * ggml_rope_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + float xpos_base, + bool xpos_down); + + // alibi position embedding + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_alibi( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_head, + float bias_max); + + // clamp + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max); + + GGML_API struct ggml_tensor * ggml_im2col( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D); + + GGML_API struct ggml_tensor * ggml_conv_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, 
// stride + int p0, // padding + int d0); // dilation + + // conv_1d with padding = half + // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) + GGML_API struct ggml_tensor* ggml_conv_1d_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s, + int d); + + GGML_API struct ggml_tensor * ggml_conv_transpose_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0); + + GGML_API struct ggml_tensor * ggml_conv_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1); + + + // kernel size is a->ne[0] x a->ne[1] + // stride is equal to kernel size + // padding is zero + // example: + // a: 16 16 3 768 + // b: 1024 1024 3 1 + // res: 64 64 768 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // kernel size is a->ne[0] x a->ne[1] + // stride is 1 + // padding is half + // example: + // a: 3 3 256 256 + // b: 64 64 256 1 + // res: 64 64 256 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride); + + enum ggml_op_pool { + GGML_OP_POOL_MAX, + GGML_OP_POOL_AVG, + GGML_OP_POOL_COUNT, + }; + + GGML_API struct ggml_tensor * ggml_pool_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, // kernel size + int s0, // stride + int p0); // padding + + // the result will have 2*p0 padding for the first dimension + // and 2*p1 padding for the second dimension + GGML_API struct ggml_tensor * ggml_pool_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1); + + // nearest interpolate + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor); + + // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0] + GGML_API struct ggml_tensor * ggml_pad( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1, + int p2, + int p3); + + // sort rows + enum ggml_sort_order { + GGML_SORT_ASC, + GGML_SORT_DESC, + }; + + GGML_API struct ggml_tensor * ggml_argsort( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_sort_order order); + + // top k elements per row + GGML_API struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k); + + GGML_API struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked); + + GGML_API struct ggml_tensor * ggml_flash_attn_back( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * d, + bool masked); + + GGML_API struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1); + + // partition into non-overlapping windows with padding if needed + // example: + // a: 768 64 64 1 + // w: 14 + // res: 768 14 14 25 + // used in sam + GGML_API struct ggml_tensor 
* ggml_win_part( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w); + + // reverse of ggml_win_part + // used in sam + GGML_API struct ggml_tensor * ggml_win_unpart( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int w); + + GGML_API struct ggml_tensor * ggml_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op); + + GGML_API struct ggml_tensor * ggml_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op); + + // used in sam + GGML_API struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh); + + // used in sam + GGML_API struct ggml_tensor * ggml_add_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + + GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + + // custom operators + + typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); + typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *); + + typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *); + typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); + typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_unary_op_f32_t fun), + "use ggml_map_custom1 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_unary_op_f32_t fun), + "use ggml_map_custom1_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_binary_op_f32_t fun), + "use ggml_map_custom2 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_binary_op_f32_t fun), + "use ggml_map_custom2_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_f32_t fun), + "use ggml_map_custom1 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_f32_t fun), + "use ggml_map_custom1_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_f32_t fun), + "use ggml_map_custom2 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_f32_t fun), + "use ggml_map_custom2_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_f32_t fun), + "use ggml_map_custom3 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor 
* ggml_map_custom3_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_f32_t fun), + "use ggml_map_custom3_inplace instead"); + + // custom operators v2 + + typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); + typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata); + typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata); + + #define GGML_N_TASKS_MAX -1 + + GGML_API struct ggml_tensor * ggml_map_custom1( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom2( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom2_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom3( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom3_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_t fun, + int n_tasks, + void * userdata); + + // loss function + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c); + + // + // automatic differentiation + // + + GGML_API void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor); + + + GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep); + + // graph allocation in a context + GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false + GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads); + GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); + GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1); + GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); + GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads + GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); + + GGML_API size_t ggml_graph_overhead(void); + GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads); + + // ggml_graph_plan() has to be called before ggml_graph_compute() + // when plan.work_size > 0, 
caller must allocate memory for plan.work_data + GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); + GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + + // same as ggml_graph_compute() but the work data is allocated as a part of the context + // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data + GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); + + GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); + + GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); + GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); + + // print info and performance information for the graph + GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); + + // dump the graph into a file using the dot format + GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + + // build gradient checkpointing backward graph gb for gf using provided checkpoints + // gb_tmp will contain original backward graph with rewritten backward process nodes, + // but without the second forward pass nodes. + GGML_API void ggml_build_backward_gradient_checkpointing( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + struct ggml_cgraph * gb_tmp, + struct ggml_tensor * * checkpoints, + int n_checkpoints); + // + // optimization + // + + // optimization methods + enum ggml_opt_type { + GGML_OPT_ADAM, + GGML_OPT_LBFGS, + }; + + // linesearch methods + enum ggml_linesearch { + GGML_LINESEARCH_DEFAULT = 1, + + GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, + GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, + GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, + }; + + // optimization return values + enum ggml_opt_result { + GGML_OPT_OK = 0, + GGML_OPT_DID_NOT_CONVERGE, + GGML_OPT_NO_CONTEXT, + GGML_OPT_INVALID_WOLFE, + GGML_OPT_FAIL, + GGML_OPT_CANCEL, + + GGML_LINESEARCH_FAIL = -128, + GGML_LINESEARCH_MINIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_ITERATIONS, + GGML_LINESEARCH_INVALID_PARAMETERS, + }; + + typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); + typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); + + // optimization parameters + // + // see ggml.c (ggml_opt_default_params) for default values + // + struct ggml_opt_params { + enum ggml_opt_type type; + + size_t graph_size; + + int n_threads; + + // delta-based convergence test + // + // if past == 0 - disabled + // if past > 0: + // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) + // + int past; + float delta; + + // maximum number of iterations without improvement + // + // if 0 - disabled + // if > 0: + // assume convergence if no cost improvement in this number of iterations + // + int max_no_improvement; + + bool print_forward_graph; + bool print_backward_graph; + + int n_gradient_accumulation; + + // ADAM parameters + struct { + int n_iter; + + float sched; // schedule multiplier (fixed, decay or warmup) + float decay; // weight decay for AdamW, use 0.0f to disable + int decay_min_ndim; // minimum number of tensor dimension to apply weight decay + float alpha; // learning rate + 
float beta1; + float beta2; + float eps; // epsilon for numerical stability + float eps_f; // epsilon for convergence test + float eps_g; // epsilon for convergence test + float gclip; // gradient clipping + } adam; + + // LBFGS parameters + struct { + int m; // number of corrections to approximate the inv. Hessian + int n_iter; + int max_linesearch; + + float eps; // convergence tolerance + float ftol; // line search tolerance + float wolfe; + float min_step; + float max_step; + + enum ggml_linesearch linesearch; + } lbfgs; + }; + + struct ggml_opt_context { + struct ggml_context * ctx; + struct ggml_opt_params params; + + int iter; + int64_t nx; // number of parameter elements + + bool just_initialized; + + float loss_before; + float loss_after; + + struct { + struct ggml_tensor * g; // current gradient + struct ggml_tensor * m; // first moment + struct ggml_tensor * v; // second moment + struct ggml_tensor * pf; // past function values + float fx_best; + float fx_prev; + int n_no_improvement; + } adam; + + struct { + struct ggml_tensor * x; // current parameters + struct ggml_tensor * xp; // previous parameters + struct ggml_tensor * g; // current gradient + struct ggml_tensor * gp; // previous gradient + struct ggml_tensor * d; // search direction + struct ggml_tensor * pf; // past function values + struct ggml_tensor * lmal; // the L-BFGS memory alpha + struct ggml_tensor * lmys; // the L-BFGS memory ys + struct ggml_tensor * lms; // the L-BFGS memory s + struct ggml_tensor * lmy; // the L-BFGS memory y + float fx_best; + float step; + int j; + int k; + int end; + int n_no_improvement; + } lbfgs; + }; + + GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); + + // optimize the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f); + + // initialize optimizer context + GGML_API void ggml_opt_init( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + int64_t nx); + + // continue optimizing the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt_resume( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f); + + // continue optimizing the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt_resume_g( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data); + + // + // quantization + // + + // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk + GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); + + GGML_API size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t 
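+    // rough usage sketch for the optimization API above (illustrative only; the iteration
+    // count is an arbitrary example value, see ggml_opt_default_params() for real defaults):
+    //
+    //   struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
+    //   params.adam.n_iter = 100;                            // example budget
+    //   enum ggml_opt_result res = ggml_opt(ctx, params, f); // f defines the function to minimize
+    //   // res == GGML_OPT_OK on convergence, see enum ggml_opt_result otherwise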
ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist); + + GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + + // + // gguf + // + + enum gguf_type { + GGUF_TYPE_UINT8 = 0, + GGUF_TYPE_INT8 = 1, + GGUF_TYPE_UINT16 = 2, + GGUF_TYPE_INT16 = 3, + GGUF_TYPE_UINT32 = 4, + GGUF_TYPE_INT32 = 5, + GGUF_TYPE_FLOAT32 = 6, + GGUF_TYPE_BOOL = 7, + GGUF_TYPE_STRING = 8, + GGUF_TYPE_ARRAY = 9, + GGUF_TYPE_UINT64 = 10, + GGUF_TYPE_INT64 = 11, + GGUF_TYPE_FLOAT64 = 12, + GGUF_TYPE_COUNT, // marks the end of the enum + }; + + struct gguf_context; + + struct gguf_init_params { + bool no_alloc; + + // if not NULL, create a ggml_context and allocate the tensor data in it + struct ggml_context ** ctx; + }; + + GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); + //GGML_API struct gguf_context * gguf_init_from_buffer(..); + + GGML_API void gguf_free(struct gguf_context * ctx); + + GGML_API const char * gguf_type_name(enum gguf_type type); + + GGML_API int gguf_get_version (const struct gguf_context * ctx); + GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + GGML_API void * gguf_get_data (const struct gguf_context * ctx); + + GGML_API int gguf_get_n_kv(const struct gguf_context * ctx); + GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key); + GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id); + + GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id); + GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id); + + // will abort if the wrong type is used for the key + GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int key_id); + GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int key_id); + GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int key_id); + GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int key_id); + GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int key_id); + GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int key_id); + GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int key_id); + GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int key_id); + GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key_id); + GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); + GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); + GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); + GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id); + GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); + GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); + GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); + + GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); + GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name); + GGML_API size_t gguf_get_tensor_offset(const struct 
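+    // rough read-side sketch for the gguf API above (illustrative only; the file name and
+    // key are placeholders):
+    //
+    //   struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
+    //   struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
+    //   int key_id = gguf_find_key(gctx, "general.architecture");
+    //   if (key_id >= 0 && gguf_get_kv_type(gctx, key_id) == GGUF_TYPE_STRING) {
+    //       const char * arch = gguf_get_val_str(gctx, key_id); // typed getters abort on a type mismatch
+    //   }
+    //   gguf_free(gctx);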
gguf_context * ctx, int i); + GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i); + + // overrides existing values or adds a new one + GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); + GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); + GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); + GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); + GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); + GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); + GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); + GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); + GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); + GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val); + GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); + GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); + GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n); + GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n); + + // set or add KV pairs from another context + GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src); + + // manage tensor info + GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); + GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); + GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size); + + // writing gguf files can be done in 2 ways: + // + // - write the entire gguf_context to a binary file in a single pass: + // + // gguf_write_to_file(ctx, fname); + // + // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: + // + // FILE * f = fopen(fname, "wb"); + // fseek(f, gguf_get_meta_size(ctx), SEEK_SET); + // fwrite(f, ...); + // void * data = gguf_meta_get_meta_data(ctx); + // fseek(f, 0, SEEK_SET); + // fwrite(f, data, gguf_get_meta_size(ctx)); + // free(data); + // fclose(f); + // + + // write the entire context to a binary file + GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); + + // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding + GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); + GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); + + // + // system info + // + + GGML_API int ggml_cpu_has_avx (void); + GGML_API int ggml_cpu_has_avx2 (void); + GGML_API int ggml_cpu_has_avx512 (void); + GGML_API int ggml_cpu_has_avx512_vbmi(void); + GGML_API int ggml_cpu_has_avx512_vnni(void); + GGML_API int ggml_cpu_has_fma (void); + GGML_API int ggml_cpu_has_neon (void); + GGML_API int ggml_cpu_has_arm_fma (void); + GGML_API int ggml_cpu_has_metal (void); + GGML_API int ggml_cpu_has_f16c (void); + GGML_API int ggml_cpu_has_fp16_va (void); + GGML_API int ggml_cpu_has_wasm_simd (void); + GGML_API int ggml_cpu_has_blas (void); + GGML_API int 
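+    // clarifying sketch of the two-pass gguf write described above, using only functions
+    // declared in this header and the standard fwrite(ptr, size, nmemb, stream) argument
+    // order (the tensor-data variables are placeholders):
+    //
+    //   FILE * f = fopen(fname, "wb");
+    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
+    //   fwrite(tensor_data, 1, tensor_data_size, f);
+    //   void * meta = malloc(gguf_get_meta_size(ctx));
+    //   gguf_get_meta_data(ctx, meta);
+    //   fseek(f, 0, SEEK_SET);
+    //   fwrite(meta, 1, gguf_get_meta_size(ctx), f);
+    //   free(meta);
+    //   fclose(f);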
ggml_cpu_has_cublas (void); + GGML_API int ggml_cpu_has_clblast (void); + GGML_API int ggml_cpu_has_gpublas (void); + GGML_API int ggml_cpu_has_sse3 (void); + GGML_API int ggml_cpu_has_ssse3 (void); + GGML_API int ggml_cpu_has_vsx (void); + + // + // Internal types and functions exposed for tests and benchmarks + // + +#ifdef __cplusplus +// restrict not standard in C++ +#define GGML_RESTRICT +#else +#define GGML_RESTRICT restrict +#endif + typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); + typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); + typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y); + + typedef struct { + const char * type_name; + int blck_size; + size_t type_size; + bool is_quantized; + ggml_to_float_t to_float; + ggml_from_float_t from_float; + ggml_from_float_t from_float_reference; + ggml_vec_dot_t vec_dot; + enum ggml_type vec_dot_type; + } ggml_type_traits_t; + + GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); + +#ifdef __cplusplus +} +#endif diff --git a/applications/src/llama/gpubuild/compile_commands.json b/applications/src/llama/gpubuild/compile_commands.json new file mode 100644 index 000000000..7a0d28722 --- /dev/null +++ b/applications/src/llama/gpubuild/compile_commands.json @@ -0,0 +1,79 @@ +[ + { + "command": "c++ -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -O3 -DNDEBUG -std=gnu++11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -Wno-array-bounds -Wno-format-truncation -Wextra-semi -o CMakeFiles/build_info.dir/build-info.cpp.o /export/users/placeholder/project/llama.cpp/common/build-info.cpp", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild/common", + "file": "/export/users/placeholder/project/llama.cpp/common/build-info.cpp" + }, + { + "command": "cc -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/. -isystem /usr/local/cuda/include -O3 -DNDEBUG -std=gnu11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Wdouble-promotion -o CMakeFiles/ggml.dir/ggml.c.o /export/users/placeholder/project/llama.cpp/ggml.c", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild", + "file": "/export/users/placeholder/project/llama.cpp/ggml.c" + }, + { + "command": "cc -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/. 
-isystem /usr/local/cuda/include -O3 -DNDEBUG -std=gnu11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Wdouble-promotion -o CMakeFiles/ggml.dir/ggml-alloc.c.o /export/users/placeholder/project/llama.cpp/ggml-alloc.c", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild", + "file": "/export/users/placeholder/project/llama.cpp/ggml-alloc.c" + }, + { + "command": "cc -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/. -isystem /usr/local/cuda/include -O3 -DNDEBUG -std=gnu11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Wdouble-promotion -o CMakeFiles/ggml.dir/ggml-backend.c.o /export/users/placeholder/project/llama.cpp/ggml-backend.c", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild", + "file": "/export/users/placeholder/project/llama.cpp/ggml-backend.c" + }, + { + "command": "cc -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/. -isystem /usr/local/cuda/include -O3 -DNDEBUG -std=gnu11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Wdouble-promotion -o CMakeFiles/ggml.dir/ggml-quants.c.o /export/users/placeholder/project/llama.cpp/ggml-quants.c", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild", + "file": "/export/users/placeholder/project/llama.cpp/ggml-quants.c" + }, + { + "command": "nvcc -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/. -isystem /usr/local/cuda/include -O3 -DNDEBUG -std=c++11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -Wno-pedantic -Wno-array-bounds -Wno-format-truncation -Wextra-semi -o CMakeFiles/ggml.dir/ggml-cuda.cu.o -D__CUDACC__=1 /export/users/placeholder/project/llama.cpp/ggml-cuda.cu", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild", + "file": "/export/users/placeholder/project/llama.cpp/ggml-cuda.cu" + }, + { + "command": "c++ -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/. 
-isystem /usr/local/cuda/include -O3 -DNDEBUG -std=gnu++11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -Wno-array-bounds -Wno-format-truncation -Wextra-semi -o CMakeFiles/llama.dir/llama.cpp.o /export/users/placeholder/project/llama.cpp/llama.cpp", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild", + "file": "/export/users/placeholder/project/llama.cpp/llama.cpp" + }, + { + "command": "ar qc libllama.a CMakeFiles/llama.dir/llama.cpp.o CMakeFiles/ggml.dir/ggml.c.o CMakeFiles/ggml.dir/ggml-alloc.c.o CMakeFiles/ggml.dir/ggml-backend.c.o CMakeFiles/ggml.dir/ggml-quants.c.o CMakeFiles/ggml.dir/ggml-cuda.cu.o", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild" + }, + { + "command": "c++ -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/common/. -I/export/users/placeholder/project/llama.cpp/. -O3 -DNDEBUG -std=gnu++11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -Wno-array-bounds -Wno-format-truncation -Wextra-semi -o CMakeFiles/common.dir/common.cpp.o /export/users/placeholder/project/llama.cpp/common/common.cpp", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild/common", + "file": "/export/users/placeholder/project/llama.cpp/common/common.cpp" + }, + { + "command": "c++ -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/common/. -I/export/users/placeholder/project/llama.cpp/. -O3 -DNDEBUG -std=gnu++11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -Wno-array-bounds -Wno-format-truncation -Wextra-semi -o CMakeFiles/common.dir/sampling.cpp.o /export/users/placeholder/project/llama.cpp/common/sampling.cpp", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild/common", + "file": "/export/users/placeholder/project/llama.cpp/common/sampling.cpp" + }, + { + "command": "c++ -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/common/. -I/export/users/placeholder/project/llama.cpp/. -O3 -DNDEBUG -std=gnu++11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -Wno-array-bounds -Wno-format-truncation -Wextra-semi -o CMakeFiles/common.dir/console.cpp.o /export/users/placeholder/project/llama.cpp/common/console.cpp", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild/common", + "file": "/export/users/placeholder/project/llama.cpp/common/console.cpp" + }, + { + "command": "c++ -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/common/. -I/export/users/placeholder/project/llama.cpp/. 
-O3 -DNDEBUG -std=gnu++11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -Wno-array-bounds -Wno-format-truncation -Wextra-semi -o CMakeFiles/common.dir/grammar-parser.cpp.o /export/users/placeholder/project/llama.cpp/common/grammar-parser.cpp", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild/common", + "file": "/export/users/placeholder/project/llama.cpp/common/grammar-parser.cpp" + }, + { + "command": "c++ -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/common/. -I/export/users/placeholder/project/llama.cpp/. -O3 -DNDEBUG -std=gnu++11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -Wno-array-bounds -Wno-format-truncation -Wextra-semi -o CMakeFiles/common.dir/train.cpp.o /export/users/placeholder/project/llama.cpp/common/train.cpp", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild/common", + "file": "/export/users/placeholder/project/llama.cpp/common/train.cpp" + }, + { + "command": "ar qc libcommon.a CMakeFiles/common.dir/common.cpp.o CMakeFiles/common.dir/sampling.cpp.o CMakeFiles/common.dir/console.cpp.o CMakeFiles/common.dir/grammar-parser.cpp.o CMakeFiles/common.dir/train.cpp.o CMakeFiles/build_info.dir/build-info.cpp.o", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild/common" + }, + { + "command": "c++ -c -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -I/export/users/placeholder/project/llama.cpp/examples -I/export/users/placeholder/project/llama.cpp/common/. -I/export/users/placeholder/project/llama.cpp/. 
-O3 -DNDEBUG -std=gnu++11 -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wmissing-declarations -Wmissing-noreturn -Wno-array-bounds -Wno-format-truncation -Wextra-semi -o CMakeFiles/main.dir/main.cpp.o /export/users/placeholder/project/llama.cpp/examples/main/main.cpp", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild/examples/main", + "file": "/export/users/placeholder/project/llama.cpp/examples/main/main.cpp" + }, + { + "command": "ld -plugin /usr/lib/gcc/x86_64-linux-gnu/11/liblto_plugin.so -plugin-opt=/usr/lib/gcc/x86_64-linux-gnu/11/lto-wrapper -plugin-opt=-fresolution=/export/users/wanghao2/temp/ccDJo76e.res -plugin-opt=-pass-through=-lgcc_s -plugin-opt=-pass-through=-lgcc -plugin-opt=-pass-through=-lc -plugin-opt=-pass-through=-lgcc_s -plugin-opt=-pass-through=-lgcc --build-id --eh-frame-hdr -melf_x86_64 --hash-style=gnu --as-needed -dynamic-linker /lib64/ld-linux-x86-64.so.2 -pie -o ../../bin/main /usr/lib/gcc/x86_64-linux-gnu/11/../../../x86_64-linux-gnu/Scrt1.o /usr/lib/gcc/x86_64-linux-gnu/11/../../../x86_64-linux-gnu/crti.o /usr/lib/gcc/x86_64-linux-gnu/11/crtbeginS.o CMakeFiles/main.dir/main.cpp.o -rpath /usr/local/cuda-12.0/targets/x86_64-linux/lib: ../../common/libcommon.a ../../libllama.a /usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudart.so /usr/local/cuda-12.0/targets/x86_64-linux/lib/libcublas.so /usr/local/cuda-12.0/targets/x86_64-linux/lib/libculibos.a /usr/local/cuda-12.0/targets/x86_64-linux/lib/libcublasLt.so /usr/lib/gcc/x86_64-linux-gnu/11/crtendS.o /usr/lib/gcc/x86_64-linux-gnu/11/../../../x86_64-linux-gnu/crtn.o", + "directory": "/export/users/placeholder/project/llama.cpp/gpubuild/examples/main" + } +] \ No newline at end of file diff --git a/applications/src/llama/llama.cpp b/applications/src/llama/llama.cpp new file mode 100644 index 000000000..99facbf77 --- /dev/null +++ b/applications/src/llama/llama.cpp @@ -0,0 +1,10360 @@ +#define LLAMA_API_INTERNAL +#include "llama.h" + +#include "unicode.h" + +#include "ggml.h" + +#include "ggml-alloc.h" + +#ifdef GGML_USE_CUBLAS +# include "ggml-cuda.h" +#elif defined(GGML_USE_CLBLAST) +# include "ggml-opencl.h" +#endif + +#ifdef GGML_USE_METAL +# include "ggml-metal.h" +#endif +#ifdef GGML_USE_MPI +# include "ggml-mpi.h" +#endif +#ifndef QK_K +# ifdef GGML_QKK_64 +# define QK_K 64 +# else +# define QK_K 256 +# endif +#endif + +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #endif + #if defined(_POSIX_MEMLOCK_RANGE) + #include + #endif + #endif +#endif + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) 
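+// on GCC/Clang the attribute above lets the compiler type-check printf-style format strings
+// (gnu_printf is needed under MinGW); on other compilers the macro expands to nothing and
+// no such checking happens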
+#endif + +#define LLAMA_MAX_NODES 8192 +#define LLAMA_MAX_EXPERTS 8 + +// +// logging +// + +LLAMA_ATTRIBUTE_FORMAT(2, 3) +static void llama_log_internal (ggml_log_level level, const char* format, ...); +static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) + +// +// helpers +// + +static size_t utf8_len(char src) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast(src) >> 4; + return lookup[highbits]; +} + +static void replace_all(std::string & s, const std::string & search, const std::string & replace) { + std::string result; + for (size_t pos = 0; ; pos += search.length()) { + auto new_pos = s.find(search, pos); + if (new_pos == std::string::npos) { + result += s.substr(pos, s.size() - pos); + break; + } + result += s.substr(pos, new_pos - pos) + replace; + pos = new_pos; + } + s = std::move(result); +} + +static bool is_float_close(float a, float b, float abs_tol) { + // Check for non-negative tolerance + if (abs_tol < 0.0) { + throw std::invalid_argument("Tolerance must be non-negative"); + } + + // Exact equality check + if (a == b) { + return true; + } + + // Check for infinities + if (std::isinf(a) || std::isinf(b)) { + return false; + } + + // Regular comparison using the provided absolute tolerance + return std::fabs(b - a) <= abs_tol; +} + +#ifdef GGML_USE_CPU_HBM +#include +#endif + +static void zeros(std::ofstream & file, size_t n) { + char zero = 0; + for (size_t i = 0; i < n; ++i) { + file.write(&zero, 1); + } +} + +LLAMA_ATTRIBUTE_FORMAT(1, 2) +static std::string format(const char * fmt, ...) 
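+// (vsnprintf is used twice: a first call with a NULL buffer measures the required size,
+//  then a second call writes into a buffer of size + 1; va_copy is needed because the
+//  first call consumes the va_list)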
{ + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +// +// gguf constants (sync with gguf.py) +// + +enum llm_arch { + LLM_ARCH_LLAMA, + LLM_ARCH_FALCON, + LLM_ARCH_BAICHUAN, + LLM_ARCH_GPT2, + LLM_ARCH_GPTJ, + LLM_ARCH_GPTNEOX, + LLM_ARCH_MPT, + LLM_ARCH_STARCODER, + LLM_ARCH_PERSIMMON, + LLM_ARCH_REFACT, + LLM_ARCH_BLOOM, + LLM_ARCH_STABLELM, + LLM_ARCH_QWEN, + LLM_ARCH_UNKNOWN, +}; + +static std::map LLM_ARCH_NAMES = { + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_PERSIMMON, "persimmon" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, +}; + +enum llm_kv { + LLM_KV_GENERAL_ARCHITECTURE, + LLM_KV_GENERAL_QUANTIZATION_VERSION, + LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_NAME, + LLM_KV_GENERAL_AUTHOR, + LLM_KV_GENERAL_URL, + LLM_KV_GENERAL_DESCRIPTION, + LLM_KV_GENERAL_LICENSE, + LLM_KV_GENERAL_SOURCE_URL, + LLM_KV_GENERAL_SOURCE_HF_REPO, + + LLM_KV_CONTEXT_LENGTH, + LLM_KV_EMBEDDING_LENGTH, + LLM_KV_BLOCK_COUNT, + LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_USE_PARALLEL_RESIDUAL, + LLM_KV_TENSOR_DATA_LAYOUT, + LLM_KV_EXPERT_COUNT, + LLM_KV_EXPERT_USED_COUNT, + + LLM_KV_ATTENTION_HEAD_COUNT, + LLM_KV_ATTENTION_HEAD_COUNT_KV, + LLM_KV_ATTENTION_MAX_ALIBI_BIAS, + LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_LAYERNORM_EPS, + LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_SCALING_TYPE, + LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, + LLM_KV_ROPE_SCALING_FINETUNED, + + LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_LIST, + LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_SCORES, + LLM_KV_TOKENIZER_MERGES, + LLM_KV_TOKENIZER_BOS_ID, + LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_UNK_ID, + LLM_KV_TOKENIZER_SEP_ID, + LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_ADD_BOS, + LLM_KV_TOKENIZER_ADD_EOS, + LLM_KV_TOKENIZER_HF_JSON, + LLM_KV_TOKENIZER_RWKV, +}; + +static std::map LLM_KV_NAMES = { + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + + { LLM_KV_ATTENTION_HEAD_COUNT, 
"%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, +}; + +struct LLM_KV { + LLM_KV(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + std::string operator()(llm_kv kv) const { + return ::format(LLM_KV_NAMES[kv].c_str(), LLM_ARCH_NAMES[arch].c_str()); + } +}; + +enum llm_tensor { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN_EXP, + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_UP_EXP, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, +}; + +static std::map> LLM_TENSOR_NAMES = { + { + LLM_ARCH_LLAMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + }, + }, + { + LLM_ARCH_BAICHUAN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + 
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_FALCON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GPT2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_GPTJ, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_GPTNEOX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_PERSIMMON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd"}, + { LLM_TENSOR_OUTPUT_NORM, "output_norm"}, + { LLM_TENSOR_OUTPUT, "output"}, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm"}, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv"}, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output"}, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"}, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"}, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm"}, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down"}, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up"}, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd"}, + }, + }, + { + LLM_ARCH_MPT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_STARCODER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_REFACT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, 
"blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_BLOOM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_STABLELM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + + { + LLM_ARCH_UNKNOWN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, +}; + +static llm_arch llm_arch_from_string(const std::string & name) { + for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT + if (kv.second == name) { + return kv.first; + } + } + + return LLM_ARCH_UNKNOWN; +} + +// helper to handle gguf constants +// usage: +// +// const auto tn = LLM_TN(LLM_ARCH_LLAMA); +// +// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" +// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" +// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" +// +struct LLM_TN { + LLM_TN(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + std::string operator()(llm_tensor tensor) const { + return LLM_TENSOR_NAMES[arch].at(tensor); + } + + std::string operator()(llm_tensor tensor, const std::string & suffix) const { + return LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix; + } + + std::string operator()(llm_tensor tensor, int bid) const { + return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid); + } + + std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { + return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix; + } + + std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const { + return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." 
+ suffix; + } +}; + +// +// gguf helpers +// + +static std::map LLAMA_ROPE_SCALING_TYPES = { + { LLAMA_ROPE_SCALING_NONE, "none" }, + { LLAMA_ROPE_SCALING_LINEAR, "linear" }, + { LLAMA_ROPE_SCALING_YARN, "yarn" }, +}; + +static int8_t llama_rope_scaling_type_from_string(const std::string & name) { + for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { + if (kv.second == name) { + return kv.first; + } + } + + return LLAMA_ROPE_SCALING_UNSPECIFIED; +} + +static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) { + switch (type) { + case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]); + case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]); + case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]); + case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]); + case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]); + case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]); + case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]); + case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); + case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); + case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); + case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + default: return format("unknown type %d", type); + } +} + +static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + + switch (type) { + case GGUF_TYPE_STRING: + return gguf_get_val_str(ctx_gguf, i); + case GGUF_TYPE_ARRAY: + { + const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); + int arr_n = gguf_get_arr_n(ctx_gguf, i); + const void * data = gguf_get_arr_data(ctx_gguf, i); + std::stringstream ss; + ss << "["; + for (int j = 0; j < arr_n; j++) { + if (arr_type == GGUF_TYPE_STRING) { + std::string val = gguf_get_arr_str(ctx_gguf, i, j); + // escape quotes + replace_all(val, "\\", "\\\\"); + replace_all(val, "\"", "\\\""); + ss << '"' << val << '"'; + } else if (arr_type == GGUF_TYPE_ARRAY) { + ss << "???"; + } else { + ss << gguf_data_to_str(arr_type, data, j); + } + if (j < arr_n - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } + default: + return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); + } +} + +// +// ggml helpers +// + +static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + + if (plan.work_size > 0) { + buf.resize(plan.work_size); + plan.work_data = buf.data(); + } + + ggml_graph_compute(graph, &plan); +} + +// +// llama helpers +// + +inline void * llama_host_malloc(size_t n) { +#ifdef GGML_USE_CUBLAS + if (ggml_cublas_loaded()) { + return ggml_cuda_host_malloc(n); + } else { + return malloc(n); + } +#elif GGML_USE_METAL + return ggml_metal_host_malloc(n); +#elif GGML_USE_CPU_HBM + return hbw_malloc(n); +#else + return malloc(n); +#endif +} + +inline void llama_host_free(void * ptr) { +#ifdef GGML_USE_CUBLAS + if (ggml_cublas_loaded()) { + return ggml_cuda_host_free(ptr); + } else { + return free(ptr); + } +#elif GGML_USE_METAL + return ggml_metal_host_free(ptr); +#elif GGML_USE_CPU_HBM + return hbw_free(ptr); +#else + return free(ptr); +#endif +} + +#if defined(_WIN32) +static std::string llama_format_win_err(DWORD err) { + LPSTR buf; + size_t size = 
FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); + if (!size) { + return "FormatMessageA failed"; + } + std::string ret(buf, size); + LocalFree(buf); + return ret; +} +#endif + +struct llama_buffer { + void * data = NULL; + size_t size = 0; + + // fallback to malloc / free + // useful in cases where CUDA can try to allocate PINNED memory + bool fallback = false; + + void resize(size_t n) { + llama_host_free(data); + + data = llama_host_malloc(n); + if (!data) { + fallback = true; + data = malloc(n); + } else { + fallback = false; + } + + GGML_ASSERT(data); + size = n; + } + + ~llama_buffer() { + if (data) { + if (fallback) { // NOLINT + free(data); + } else { + llama_host_free(data); + } + } + + data = NULL; + } +}; + +struct llama_file { + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + size_t size; + + llama_file(const char * fname, const char * mode) { + fp = std::fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + GGML_ASSERT(ret != -1); // this really shouldn't fail + return (size_t) ret; + } + + void seek(size_t offset, int whence) const { +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + GGML_ASSERT(ret == 0); // same + } + + void read_raw(void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, len, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error(std::string("unexpectedly reached end of file")); + } + } + + uint32_t read_u32() const { + uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + void write_raw(const void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + size_t ret = std::fwrite(ptr, len, 1, fp); + if (ret != 1) { + throw std::runtime_error(format("write error: %s", strerror(errno))); + } + } + + void write_u32(std::uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~llama_file() { + if (fp) { + std::fclose(fp); + } + } +}; + +struct llama_mmap { + void * addr; + size_t size; + + llama_mmap(const llama_mmap &) = delete; + +#ifdef _POSIX_MAPPED_FILES + static constexpr bool SUPPORTED = true; + + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) { + size = file->size; + int fd = fileno(file->fp); + int flags = MAP_SHARED; + // prefetch/readahead impairs performance on NUMA systems + if (numa) { prefetch = 0; } +#ifdef __linux__ + if (prefetch) { flags |= MAP_POPULATE; } +#endif + addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); + if (addr == MAP_FAILED) { + throw std::runtime_error(format("mmap failed: %s", strerror(errno))); + } + + if (prefetch > 0) { + // Advise the kernel to preload the mapped memory + if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) { + fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", + strerror(errno)); + } + } + if (numa) { + // advise the kernel not to use readahead + // (because the next page might not belong on the same node) 
+ if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) { + fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", + strerror(errno)); + } + } + } + + ~llama_mmap() { + munmap(addr, size); + } +#elif defined(_WIN32) + static constexpr bool SUPPORTED = true; + + llama_mmap(struct llama_file * file, bool prefetch = true, bool numa = false) { + (void) numa; + + size = file->size; + + HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp)); + + HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + DWORD error = GetLastError(); + + if (hMapping == NULL) { + throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); + } + + addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); + error = GetLastError(); + CloseHandle(hMapping); + + if (addr == NULL) { + throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); + } + + if (prefetch) { + // PrefetchVirtualMemory is only present on Windows 8 and above, so we dynamically load it + BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG); + HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll"); + + // may fail on pre-Windows 8 systems + pPrefetchVirtualMemory = reinterpret_cast (GetProcAddress(hKernel32, "PrefetchVirtualMemory")); + + if (pPrefetchVirtualMemory) { + // advise the kernel to preload the mapped memory + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T)size; + if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } + } + } + + ~llama_mmap() { + if (!UnmapViewOfFile(addr)) { + fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static constexpr bool SUPPORTED = false; + + llama_mmap(struct llama_file * file, bool prefetch = true, bool numa = false) { + (void) file; + (void) prefetch; + (void) numa; + + throw std::runtime_error(std::string("mmap not supported")); + } +#endif +}; + +// Represents some region of memory being locked using mlock or VirtualLock; +// will automatically unlock on destruction. +struct llama_mlock { + void * addr = NULL; + size_t size = 0; + + bool failed_already = false; + + llama_mlock() {} + llama_mlock(const llama_mlock &) = delete; + + ~llama_mlock() { + if (size) { + raw_unlock(addr, size); + } + } + + void init(void * ptr) { + GGML_ASSERT(addr == NULL && size == 0); // NOLINT + addr = ptr; + } + + void grow_to(size_t target_size) { + GGML_ASSERT(addr); + if (failed_already) { + return; + } + size_t granularity = lock_granularity(); + target_size = (target_size + granularity - 1) & ~(granularity - 1); + if (target_size > size) { + if (raw_lock((uint8_t *) addr + size, target_size - size)) { + size = target_size; + } else { + failed_already = true; + } + } + } + +#ifdef _POSIX_MEMLOCK_RANGE + static constexpr bool SUPPORTED = true; + + static size_t lock_granularity() { + return (size_t) sysconf(_SC_PAGESIZE); + } + + #ifdef __APPLE__ + #define MLOCK_SUGGESTION \ + "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \ + "decreasing 'vm.global_no_user_wire_amount'. 
Also try increasing RLIMIT_MLOCK (ulimit -l).\n" + #else + #define MLOCK_SUGGESTION \ + "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n" + #endif + + bool raw_lock(const void * addr, size_t size) const { + if (!mlock(addr, size)) { + return true; + } + + char* errmsg = std::strerror(errno); + bool suggest = (errno == ENOMEM); + + // Check if the resource limit is fine after all + struct rlimit lock_limit; + if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { + suggest = false; + } + if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + suggest = false; + } + + fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", + size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); + return false; + } + + #undef MLOCK_SUGGESTION + + static void raw_unlock(void * addr, size_t size) { + if (munlock(addr, size)) { + fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno)); + } + } +#elif defined(_WIN32) + static constexpr bool SUPPORTED = true; + + static size_t lock_granularity() { + SYSTEM_INFO si; + GetSystemInfo(&si); + return (size_t) si.dwPageSize; + } + + bool raw_lock(void * ptr, size_t len) const { + for (int tries = 1; ; tries++) { + if (VirtualLock(ptr, len)) { + return true; + } + if (tries == 2) { + fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", + len, size, llama_format_win_err(GetLastError()).c_str()); + return false; + } + + // It failed but this was only the first try; increase the working + // set size and try again. + SIZE_T min_ws_size, max_ws_size; + if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) { + fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + // Per MSDN: "The maximum number of pages that a process can lock + // is equal to the number of pages in its minimum working set minus + // a small overhead." 
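+            // e.g. if VirtualLock failed for a 512 MiB buffer, both working set bounds are
+            // grown by 512 MiB + 1 MiB below and the lock is retried exactly once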
+ // Hopefully a megabyte is enough overhead: + size_t increment = len + 1048576; + // The minimum must be <= the maximum, so we need to increase both: + min_ws_size += increment; + max_ws_size += increment; + if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) { + fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + } + } + + static void raw_unlock(void * ptr, size_t len) { + if (!VirtualUnlock(ptr, len)) { + fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static constexpr bool SUPPORTED = false; + + static size_t lock_granularity() { + return (size_t) 65536; + } + + bool raw_lock(const void * addr, size_t len) const { + fprintf(stderr, "warning: mlock not supported on this system\n"); + return false; + } + + static void raw_unlock(const void * addr, size_t len) {} +#endif +}; + +typedef void (*offload_func_t)(struct ggml_tensor * tensor); + +static void ggml_offload_nop(struct ggml_tensor * tensor) { + (void) tensor; +} + +static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); + GGML_ASSERT(check == -n_tokens); + } + else { + result.resize(n_tokens); + } + + return std::string(result.data(), result.size()); +} + +// +// globals +// + +struct llama_state { + llama_state() { +#ifdef GGML_USE_METAL + ggml_metal_log_set_callback(log_callback, log_callback_user_data); +#endif + } + + // We save the log callback globally + ggml_log_callback log_callback = llama_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static llama_state g_state; + +// available llama models +enum e_model { + MODEL_UNKNOWN, + MODEL_1B, + MODEL_3B, + MODEL_7B, + MODEL_8B, + MODEL_13B, + MODEL_15B, + MODEL_30B, + MODEL_34B, + MODEL_40B, + MODEL_65B, + MODEL_70B, +}; + +static const size_t kiB = 1024; +static const size_t MiB = 1024*kiB; +static const size_t GiB = 1024*MiB; + +struct llama_hparams { + bool vocab_only; + uint32_t n_vocab; + uint32_t n_ctx_train; // context size the model was trained on + uint32_t n_embd; + uint32_t n_head; + uint32_t n_head_kv; + uint32_t n_layer; + uint32_t n_rot; + uint32_t n_ff; + uint32_t n_expert = 0; + uint32_t n_expert_used = 0; + + float f_norm_eps; + float f_norm_rms_eps; + + float rope_freq_base_train; + float rope_freq_scale_train; + uint32_t n_yarn_orig_ctx; + int8_t rope_scaling_type_train : 3; + bool rope_finetuned : 1; + + float f_clamp_kqv; + float f_max_alibi_bias; + + bool operator!=(const llama_hparams & other) const { + if (this->vocab_only != other.vocab_only) return true; + if (this->n_vocab != other.n_vocab) return true; + if (this->n_ctx_train != other.n_ctx_train) return true; + if (this->n_embd != other.n_embd) return true; + if (this->n_head != other.n_head) return true; + if (this->n_head_kv != other.n_head_kv) return true; + if (this->n_layer != other.n_layer) return true; + if (this->n_rot != other.n_rot) return true; + if (this->n_ff != other.n_ff) return true; + if (this->n_expert != other.n_expert) return true; + if (this->n_expert_used != other.n_expert_used) return true; + + if (this->rope_finetuned != other.rope_finetuned) return 
true; + if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; + + const float EPSILON = 1e-9; + + if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; + if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; + if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; + if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; + + return false; + } + + uint32_t n_gqa() const { + return n_head/n_head_kv; + } + + uint32_t n_embd_head() const { + return n_embd/n_head; + } + + uint32_t n_embd_gqa() const { + return n_embd/n_gqa(); + } +}; + +struct llama_cparams { + uint32_t n_ctx; // context size used during inference + uint32_t n_batch; + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing + + float rope_freq_base; + float rope_freq_scale; + + uint32_t n_yarn_orig_ctx; + // These hyperparameters are not exposed in GGUF, because all + // existing YaRN models use the same values for them. + float yarn_ext_factor; + float yarn_attn_factor; + float yarn_beta_fast; + float yarn_beta_slow; + + bool mul_mat_q; + bool offload_kqv; +}; + +struct llama_layer { + // normalization + struct ggml_tensor * attn_norm; + struct ggml_tensor * attn_norm_b; + struct ggml_tensor * attn_norm_2; + struct ggml_tensor * attn_norm_2_b; + struct ggml_tensor * attn_q_norm; + struct ggml_tensor * attn_q_norm_b; + struct ggml_tensor * attn_k_norm; + struct ggml_tensor * attn_k_norm_b; + + // attention + struct ggml_tensor * wq; + struct ggml_tensor * wk; + struct ggml_tensor * wv; + struct ggml_tensor * wo; + struct ggml_tensor * wqkv; + + // attention bias + struct ggml_tensor * bq; + struct ggml_tensor * bk; + struct ggml_tensor * bv; + struct ggml_tensor * bo; + struct ggml_tensor * bqkv; + + // normalization + struct ggml_tensor * ffn_norm; + struct ggml_tensor * ffn_norm_b; + + // ff + struct ggml_tensor * ffn_gate; // w1 + struct ggml_tensor * ffn_down; // w2 + struct ggml_tensor * ffn_up; // w3 + + // ff MoE + struct ggml_tensor * ffn_gate_inp; + struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS]; + struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS]; + struct ggml_tensor * ffn_up_exp [LLAMA_MAX_EXPERTS]; + + // ff bias + struct ggml_tensor * ffn_down_b; // b2 + struct ggml_tensor * ffn_up_b; // b3 +}; + +struct llama_kv_cell { + llama_pos pos = -1; + llama_pos delta = 0; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } +}; + +// ring-buffer of cached KV data +struct llama_kv_cache { + bool has_shift = false; + + // Note: The value of head isn't only used to optimize searching + // for a free KV slot. llama_decode_internal also uses it, so it + // cannot be freely changed after a slot has been allocated. + uint32_t head = 0; + uint32_t size = 0; + uint32_t used = 0; // used cells (i.e. 
at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + + struct ggml_context * ctx = NULL; + + llama_buffer buf; + + ~llama_kv_cache() { + if (ctx) { + ggml_free(ctx); + } + +#ifdef GGML_USE_CUBLAS + if (ggml_cublas_loaded()) { + for (size_t i = 0; i < k_l.size(); ++i) { + ggml_cuda_free_data(k_l[i]); + ggml_cuda_free_data(v_l[i]); + } + } +#endif + } +}; + +struct llama_vocab { + using id = int32_t; + using token = std::string; + using ttype = llama_token_type; + + struct token_data { + token text; + float score; + ttype type; + }; + + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + + std::unordered_map token_to_id; + std::vector id_to_token; + + std::unordered_map special_tokens_cache; + + std::map, int> bpe_ranks; + + // default LLaMA special tokens + id special_bos_id = 1; + id special_eos_id = 2; + id special_unk_id = 0; + id special_sep_id = -1; + id special_pad_id = -1; + + int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add. + int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add. + + id linefeed_id = 13; + id special_prefix_id = 32007; + id special_middle_id = 32009; + id special_suffix_id = 32008; + id special_eot_id = 32010; + + int find_bpe_rank(std::string token_left, std::string token_right) const { + GGML_ASSERT(token_left.find(" ") == std::string::npos); + GGML_ASSERT(token_left.find("\n") == std::string::npos); + GGML_ASSERT(token_right.find(" ") == std::string::npos); + GGML_ASSERT(token_right.find("\n") == std::string::npos); + + auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); + if (it == bpe_ranks.end()) { + return -1; + } + + return it->second; + } +}; + +struct llama_model { + e_model type = MODEL_UNKNOWN; + llm_arch arch = LLM_ARCH_UNKNOWN; + llama_ftype ftype = LLAMA_FTYPE_ALL_F32; + + std::string name = "n/a"; + + llama_hparams hparams = {}; + llama_vocab vocab; + + struct ggml_tensor * tok_embd; + struct ggml_tensor * pos_embd; + struct ggml_tensor * tok_norm; + struct ggml_tensor * tok_norm_b; + + struct ggml_tensor * output_norm; + struct ggml_tensor * output_norm_b; + struct ggml_tensor * output; + + std::vector layers; + + int n_gpu_layers; + + // gguf metadata + std::unordered_map gguf_kv; + + // context + struct ggml_context * ctx = NULL; + + // the model memory buffer + llama_buffer buf; + + // model memory mapped file + std::unique_ptr mapping; + + // objects representing data potentially being locked in memory + llama_mlock mlock_buf; + llama_mlock mlock_mmap; + + // for quantize-stats only + std::vector> tensors_by_name; + + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ~llama_model() { + if (ctx) { + ggml_free(ctx); + } + +#ifdef GGML_USE_CUBLAS + if (ggml_cublas_loaded()) { + for (size_t i = 0; i < tensors_by_name.size(); ++i) { + ggml_cuda_free_data(tensors_by_name[i].second); + } + ggml_cuda_free_scratch(); + } +#endif + +#if defined(GGML_USE_CLBLAST) + for (size_t i = 0; i < tensors_by_name.size(); ++i) { + ggml_cl_free_data(tensors_by_name[i].second); + } +#endif + } +}; + +struct llama_context { + llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} + ~llama_context() { +#ifdef GGML_USE_METAL + if (ctx_metal) { + ggml_metal_free(ctx_metal); + } +#endif + if (alloc) { + ggml_allocr_free(alloc); + } + } + + llama_cparams cparams; + + const llama_model & model; + + // key + value cache for the self attention + struct llama_kv_cache kv_self; + 
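+    // Note: kv_self is sized by llama_kv_cache_init() below: one K and one V
+    // tensor per layer, each holding n_embd_gqa()*n_ctx elements of the
+    // requested ktype/vtype, managed as a ring buffer via kv_self.head.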
+ std::mt19937 rng; + + bool has_evaluated_once = false; + + int64_t t_start_us; + int64_t t_load_us; + int64_t t_sample_us = 0; + int64_t t_p_eval_us = 0; + int64_t t_eval_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + int32_t n_eval = 0; // number of eval calls + + // decode output (2-dimensional array: [n_tokens][n_vocab]) + std::vector logits; +#ifndef NDEBUG + // guard against access to unset logits + std::vector logits_valid; +#endif + bool logits_all = false; + + // input embedding (1-dimensional array: [n_embd]) + std::vector embedding; + + // reusable buffer for `struct ggml_graph_plan.work_data` + std::vector work_buffer; + + // memory buffers used to evaluate the model + llama_buffer buf_compute; + + llama_buffer buf_alloc; + ggml_allocr * alloc = NULL; + +#ifdef GGML_USE_METAL + ggml_metal_context * ctx_metal = NULL; +#endif + +#ifdef GGML_USE_MPI + ggml_mpi_context * ctx_mpi = NULL; +#endif +}; + +// +// kv cache helpers +// + +static bool llama_kv_cache_init( + const struct llama_hparams & hparams, + struct llama_kv_cache & cache, + ggml_type ktype, + ggml_type vtype, + uint32_t n_ctx, + int n_gpu_layers, + bool offload) { + const uint32_t n_embd = hparams.n_embd_gqa(); + const uint32_t n_layer = hparams.n_layer; + + const int64_t n_mem = n_layer*n_ctx; + const int64_t n_elements = n_embd*n_mem; + + cache.has_shift = false; + + cache.head = 0; + cache.size = n_ctx; + cache.used = 0; + + cache.cells.clear(); + cache.cells.resize(n_ctx); + + cache.buf.resize(ggml_row_size(ktype, n_elements) + ggml_row_size(vtype, n_elements) + 2u*n_layer*ggml_tensor_overhead()); + memset(cache.buf.data, 0, cache.buf.size); + + struct ggml_init_params params; + params.mem_size = cache.buf.size; + params.mem_buffer = cache.buf.data; + params.no_alloc = false; + + cache.ctx = ggml_init(params); + + size_t vram_kv_cache = 0; + + if (!cache.ctx) { + LLAMA_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k_l.reserve(n_layer); + cache.v_l.reserve(n_layer); + + const int i_gpu_start = (int) n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start); + + GGML_UNUSED(offload); + + for (int i = 0; i < (int) n_layer; i++) { + ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd*n_ctx); + ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, vtype, n_embd*n_ctx); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.k_l.push_back(k); + cache.v_l.push_back(v); +#ifdef GGML_USE_CUBLAS + if (i >= i_gpu_start) { + if (offload) { + ggml_cuda_assign_buffers_no_scratch(k); + vram_kv_cache += ggml_nbytes(k); + ggml_cuda_assign_buffers_no_scratch(v); + vram_kv_cache += ggml_nbytes(v); + } + } +#endif // GGML_USE_CUBLAS + } + + if (vram_kv_cache > 0) { + LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0); + } + + GGML_UNUSED(n_gpu_layers); + + return true; +} + +// find an empty slot of size "n_tokens" in the cache +// updates the cache head +// Note: On success, it's important that cache.head points +// to the first cell of the slot. 
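+// Example (illustrative numbers): with size = 8, head = 6 and a batch of
+// 3 tokens, the slot would run past the end of the buffer, so the search
+// wraps (head = 0); if cells 0..2 all have pos < 0 they are claimed,
+// head stays at 0 and cache.used grows by 3.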
+static bool llama_kv_cache_find_slot( + struct llama_kv_cache & cache, + const struct llama_batch & batch) { + const uint32_t n_ctx = cache.size; + const uint32_t n_tokens = batch.n_tokens; + + if (n_tokens > n_ctx) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); + return false; + } + + uint32_t n_tested = 0; + + while (true) { + if (cache.head + n_tokens > n_ctx) { + n_tested += n_ctx - cache.head; + cache.head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.cells[cache.head + i].pos >= 0) { + found = false; + cache.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= n_ctx) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + + for (uint32_t i = 0; i < n_tokens; i++) { + cache.cells[cache.head + i].pos = batch.pos[i]; + + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + } + } + + cache.used += n_tokens; + + return true; +} + +// find how many cells are currently in use +static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { + for (uint32_t i = cache.size - 1; i > 0; --i) { + if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) { + return i + 1; + } + } + + return 0; +} + +static void llama_kv_cache_clear(struct llama_kv_cache & cache) { + for (int32_t i = 0; i < (int32_t) cache.size; ++i) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } + cache.head = 0; + cache.used = 0; +} + +static void llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + if (seq_id < 0) { + cache.cells[i].seq_id.clear(); + } else if (cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cache.cells[i].seq_id.empty()) { + // keep count of the number of used cells + if (cache.cells[i].pos >= 0) cache.used--; + + cache.cells[i].pos = -1; + if (new_head == cache.size) new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size && new_head < cache.head) cache.head = new_head; +} + +static void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + cache.head = 0; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.insert(seq_id_dst); + } + } +} + +static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { + uint32_t new_head = cache.size; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (!cache.cells[i].has_seq_id(seq_id)) { + if (cache.cells[i].pos >= 0) cache.used--; + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; + } else { + cache.cells[i].seq_id.clear(); + cache.cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. 
+ if (new_head != cache.size && new_head < cache.head) cache.head = new_head; +} + +static void llama_kv_cache_seq_shift( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.has_shift = true; + cache.cells[i].pos += delta; + cache.cells[i].delta += delta; + + if (cache.cells[i].pos < 0) { + if (!cache.cells[i].seq_id.empty()) cache.used--; + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.head = new_head != cache.size ? new_head : 0; +} + +// +// model loading and saving +// + +enum llama_fver { + GGUF_FILE_VERSION_V1 = 1, + GGUF_FILE_VERSION_V2 = 2, + GGUF_FILE_VERSION_V3 = 3, +}; + +static const char * llama_file_version_name(llama_fver version) { + switch (version) { + case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; + case GGUF_FILE_VERSION_V2: return "GGUF V2"; + case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)"; + } + + return "unknown"; +} + +static std::string llama_format_tensor_shape(const std::vector & ne) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + for (size_t i = 1; i < ne.size(); i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + } + return buf; +} + +static std::string llama_format_tensor_shape(const struct ggml_tensor * t) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + } + return buf; +} + +namespace GGUFMeta { + template + struct GKV_Base_Type { + static constexpr gguf_type gt = gt_; + + static T getter(const gguf_context * ctx, const int kid) { + return gfun(ctx, kid); + } + }; + + template struct GKV_Base; + + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + + template<> struct GKV_Base { + static constexpr gguf_type gt = GGUF_TYPE_STRING; + + static std::string getter(const gguf_context * ctx, const int kid) { + return gguf_get_val_str(ctx, kid); + } + }; + + struct ArrayInfo{ + const gguf_type gt; + const size_t length; + const void * data; + }; + + template<> struct GKV_Base { + public: + static constexpr gguf_type gt = GGUF_TYPE_ARRAY; + static ArrayInfo getter(const gguf_context *ctx, const int k) { + return ArrayInfo { + gguf_get_arr_type(ctx, k), + size_t(gguf_get_arr_n(ctx, k)), + gguf_get_arr_data(ctx, k), + }; + } + }; + + template + class GKV: public GKV_Base { + GKV() = delete; + + public: + static T get_kv(const gguf_context * ctx, const int k) { + const enum gguf_type kt = 
gguf_get_kv_type(ctx, k); + + if (kt != GKV::gt) { + throw std::runtime_error(format("key %s has wrong type %s but expected type %s", + gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt))); + } + return GKV::getter(ctx, k); + } + + static const char * override_type_to_str(const llama_model_kv_override_type ty) { + switch (ty) { + case LLAMA_KV_OVERRIDE_BOOL: return "bool"; + case LLAMA_KV_OVERRIDE_INT: return "int"; + case LLAMA_KV_OVERRIDE_FLOAT: return "float"; + } + return "unknown"; + } + + static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override *override) { + if (!override) { return false; } + if (override->tag == expected_type) { + LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ", + __func__, override_type_to_str(override->tag), override->key); + switch (override->tag) { + case LLAMA_KV_OVERRIDE_BOOL: { + printf("%s\n", override->bool_value ? "true" : "false"); + } break; + case LLAMA_KV_OVERRIDE_INT: { + printf("%" PRId64 "\n", override->int_value); + } break; + case LLAMA_KV_OVERRIDE_FLOAT: { + printf("%.6f\n", override->float_value); + } break; + default: + // Shouldn't be possible to end up here, but just in case... + throw std::runtime_error( + format("Unsupported attempt to override %s type for metadata key %s\n", + override_type_to_str(override->tag), override->key)); + } + return true; + } + LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n", + __func__, override->key, override_type_to_str(expected_type), override_type_to_str(override->tag)); + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override *override) { + if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) { + target = override->bool_value; + return true; + } + return false; + } + + template + static typename std::enable_if::value && std::is_integral::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override *override) { + if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) { + target = override->int_value; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override *override) { + if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) { + target = override->float_value; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override *override) { + (void)target; + (void)override; + if (!override) { return false; } + // Currently, we should never end up here so it would be a bug if we do. + throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n", + override ? 
override->key : "NULL")); + } + + static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) { + if (try_override(target, override)) { + return true; + } + if (k < 0) { return false; } + target = get_kv(ctx, k); + return true; + } + + static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) { + return set(ctx, gguf_find_key(ctx, key), target, override); + } + + static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) { + return set(ctx, key.c_str(), target, override); + } + }; +} + +struct llama_model_loader { + int n_kv = 0; + int n_tensors = 0; + int n_created = 0; + + int64_t n_elements = 0; + size_t n_bytes = 0; + + bool use_mmap = false; + + llama_file file; + llama_ftype ftype; + llama_fver fver; + + std::unique_ptr mapping; + std::unordered_map kv_overrides; + + struct gguf_context * ctx_gguf = NULL; + struct ggml_context * ctx_meta = NULL; + + std::string arch_name; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) : file(fname.c_str(), "rb") { + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx_meta, + }; + + if (param_overrides_p != nullptr) { + for (const struct llama_model_kv_override *p = param_overrides_p; p->key[0] != 0; p++) { + kv_overrides.insert({std::string(p->key), *p}); + } + } + + ctx_gguf = gguf_init_from_file(fname.c_str(), params); + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + } + + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + n_kv = gguf_get_n_kv(ctx_gguf); + n_tensors = gguf_get_n_tensors(ctx_gguf); + + fver = (enum llama_fver ) gguf_get_version(ctx_gguf); + + for (int i = 0; i < n_tensors; i++) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + struct ggml_tensor * t = ggml_get_tensor(ctx_meta, name); + n_elements += ggml_nelements(t); + n_bytes += ggml_nbytes(t); + } + + LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", + __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + + // determine file type based on the number of tensors for each quantization and print meta data + // TODO: make optional + { + std::map n_type; + + uint32_t n_type_max = 0; + enum ggml_type type_max = GGML_TYPE_F32; + + for (int i = 0; i < n_tensors; i++) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + struct ggml_tensor * meta = ggml_get_tensor(ctx_meta, name); + + n_type[meta->type]++; + + if (n_type_max < n_type[meta->type]) { + n_type_max = n_type[meta->type]; + type_max = meta->type; + } + + LLAMA_LOG_INFO("%s: - tensor %4d: %32s %-8s [ %s ]\n", __func__, i, name, ggml_type_name(meta->type), llama_format_tensor_shape(meta).c_str()); + } + + switch (type_max) { + case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; + case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; + case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; + case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; + case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; + case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; + case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; 
break; + case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; + case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; + case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; + case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; + case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + default: + { + LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); + ftype = LLAMA_FTYPE_ALL_F32; + } break; + } + + // this is a way to mark that we have "guessed" the file type + ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); + + { + const int kid = gguf_find_key(ctx_gguf, "general.file_type"); + if (kid >= 0) { + ftype = (llama_ftype) gguf_get_val_u32(ctx_gguf, kid); + } + } + + LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); + for (int i = 0; i < n_kv; i++) { + const char * name = gguf_get_key(ctx_gguf, i); + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + const std::string type_name = + type == GGUF_TYPE_ARRAY + ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(ctx_gguf, i)), gguf_get_arr_n(ctx_gguf, i)) + : gguf_type_name(type); + + std::string value = gguf_kv_to_str(ctx_gguf, i); + const size_t MAX_VALUE_LEN = 40; + if (value.size() > MAX_VALUE_LEN) { + value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); + } + replace_all(value, "\n", "\\n"); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); + } + + // print type counts + for (auto & kv : n_type) { + if (kv.second == 0) { + continue; + } + + LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); + } + } + + if (!llama_mmap::SUPPORTED) { + LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); + use_mmap = false; + } + + this->use_mmap = use_mmap; + } + + ~llama_model_loader() { + if (ctx_gguf) { + gguf_free(ctx_gguf); + } + if (ctx_meta) { + ggml_free(ctx_meta); + } + } + + template + typename std::enable_if::value, bool>::type + get_arr_n(const std::string & key, T & result, const bool required = true) { + const int kid = gguf_find_key(ctx_gguf, key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(ctx_gguf, kid); + + + result = arr_info.length; + return true; + } + + template + typename std::enable_if::value, bool>::type + get_arr_n(const enum llm_kv kid, T & result, const bool required = true) { + return get_arr_n(llm_kv(kid), result, required); + } + + template + bool get_key(const std::string & key, T & result, const bool required = true) { + auto it = kv_overrides.find(key); + + const struct llama_model_kv_override * override = + it != kv_overrides.end() ? 
&it->second : nullptr; + + const bool found = GGUFMeta::GKV::set(ctx_gguf, key, result, override); + + if (required && !found) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + + return found; + } + + template + bool get_key(const enum llm_kv kid, T & result, const bool required = true) { + return get_key(llm_kv(kid), result, required); + } + + std::string get_arch_name() const { + return arch_name; + } + + enum llm_arch get_arch() const { + return llm_kv.arch; + } + + const char * get_tensor_name(int i) const { + return gguf_get_tensor_name(ctx_gguf, i); + } + + struct ggml_tensor * get_tensor_meta(int i) const { + return ggml_get_tensor(ctx_meta, get_tensor_name(i)); + } + + void calc_sizes(size_t & ctx_size_p, size_t & mmapped_size_p) const { + ctx_size_p = 0; + mmapped_size_p = 0; + + for (int i = 0; i < n_tensors; i++) { + struct ggml_tensor * meta = get_tensor_meta(i); + ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE; + (use_mmap ? mmapped_size_p : ctx_size_p) += ggml_nbytes_pad(meta); + } + } + + struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend_type backend) { + if (backend != GGML_BACKEND_CPU) { + ggml_set_no_alloc(ctx, true); + } + + struct ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + tensor->backend = backend; // TODO: ggml_set_backend + ggml_set_name(tensor, ggml_get_name(meta)); + + if (backend != GGML_BACKEND_CPU) { + ggml_set_no_alloc(ctx, use_mmap); + } + + n_created++; + + return tensor; + } + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend_type backend, bool required = true) { + struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str()); + + if (cur == NULL) { + if (!required) { + return NULL; + } + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + + if (backend == GGML_BACKEND_GPU_SPLIT) { + if (ne.size() == 1) { + throw std::runtime_error(format("%s: 1-dimensional tensor '%s' cannot be split on the GPU", __func__, name.c_str())); + } + } + + { + bool is_ok = true; + for (size_t i = 0; i < ne.size(); ++i) { + if (ne[i] != cur->ne[i]) { + is_ok = false; + break; + } + } + if (!is_ok) { + throw std::runtime_error( + format("%s: tensor '%s' has wrong shape; expected %s, got %s", + __func__, name.c_str(), + llama_format_tensor_shape(ne).c_str(), + llama_format_tensor_shape(cur).c_str())); + } + } + + return create_tensor_for(ctx, cur, backend); + } + + void done_getting_tensors() const { + if (n_created != n_tensors) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + } + + size_t file_offset(const char * name) const { + const int idx = gguf_find_tensor(ctx_gguf, name); + + if (idx < 0) { + throw std::runtime_error(format("%s: tensor '%s' not found in the file", __func__, name)); + } + + return gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, idx); + } + + void load_data_for(struct ggml_tensor * cur) const { + const size_t offs = file_offset(ggml_get_name(cur)); + + if (use_mmap) { + cur->data = (uint8_t *) mapping->addr + offs; + } else { + file.seek(offs, SEEK_SET); + file.read_raw(cur->data, ggml_nbytes(cur)); + } + } + + void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { + size_t size_data = 0; + size_t size_lock = 0; + size_t size_pref = 0; // 
prefetch + + for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); + size_data += ggml_nbytes(cur); + if (cur->backend == GGML_BACKEND_CPU) { + size_pref += ggml_nbytes(cur); + } + } + + if (use_mmap) { + mapping.reset(new llama_mmap(&file, size_pref, ggml_is_numa())); + if (lmlock) { + lmlock->init(mapping->addr); + } + } + + size_t done_size = 0; + for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); + GGML_ASSERT(cur); // unused tensors should have been caught by load_data already + + if (progress_callback) { + progress_callback((float) done_size / size_data, progress_callback_user_data); + } + + // allocate temp buffer if not using mmap + if (!use_mmap && cur->data == NULL) { + GGML_ASSERT(cur->backend != GGML_BACKEND_CPU); + #ifdef GGML_USE_CPU_HBM + cur->data = (uint8_t*)hbw_malloc(ggml_nbytes(cur)); + #else + cur->data = (uint8_t*)malloc(ggml_nbytes(cur)); + #endif + } + + load_data_for(cur); + + switch (cur->backend) { + case GGML_BACKEND_CPU: + if (use_mmap && lmlock) { + size_lock += ggml_nbytes(cur); + lmlock->grow_to(size_lock); + } + break; +#ifdef GGML_USE_CUBLAS + case GGML_BACKEND_GPU: + case GGML_BACKEND_GPU_SPLIT: + // old code: + //ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor); + + // TODO: test if this works !! + ggml_cuda_transform_tensor(cur->data, cur); + if (!use_mmap) { + free(cur->data); + } + break; +#elif defined(GGML_USE_CLBLAST) + case GGML_BACKEND_GPU: + ggml_cl_transform_tensor(cur->data, cur); + if (!use_mmap) { + free(cur->data); + } + break; +#endif + default: + continue; + } + + done_size += ggml_nbytes(cur); + } + } +}; + +// +// load LLaMA models +// + +static std::string llama_model_arch_name(llm_arch arch) { + auto it = LLM_ARCH_NAMES.find(arch); + if (it == LLM_ARCH_NAMES.end()) { + return "unknown"; + } + return it->second; +} + +static std::string llama_model_ftype_name(llama_ftype ftype) { + if (ftype & LLAMA_FTYPE_GUESSED) { + return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; + } + + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: + return "Q4_1, some F16"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + + default: return "unknown, may not work"; + } +} + +static const char * llama_model_type_name(e_model type) { + switch (type) { + case MODEL_1B: return "1B"; + case MODEL_3B: return "3B"; + case MODEL_7B: return "7B"; + case MODEL_8B: return "8B"; + case MODEL_13B: return "13B"; + case MODEL_15B: return "15B"; + case MODEL_30B: return "30B"; + case MODEL_34B: return "34B"; + case MODEL_40B: return "40B"; + case MODEL_65B: return 
"65B"; + case MODEL_70B: return "70B"; + default: return "?B"; + } +} + +static void llm_load_arch(llama_model_loader & ml, llama_model & model) { + model.arch = ml.get_arch(); + if (model.arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); + } +} + +static void llm_load_hparams( + llama_model_loader & ml, + llama_model & model) { + auto & hparams = model.hparams; + const gguf_context * ctx = ml.ctx_gguf; + + // get metadata as string + for (int i = 0; i < gguf_get_n_kv(ctx); i++) { + enum gguf_type type = gguf_get_kv_type(ctx, i); + if (type == GGUF_TYPE_ARRAY) { + continue; + } + const char * name = gguf_get_key(ctx, i); + const std::string value = gguf_kv_to_str(ctx, i); + model.gguf_kv.emplace(name, value); + } + + // get general kv + ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); + + // get hparams kv + ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); + ml.get_key (LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + ml.get_key (LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); + ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); + ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + + GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); + GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); + if (hparams.n_expert > 0) { + GGML_ASSERT(hparams.n_expert_used > 0); + } else { + GGML_ASSERT(hparams.n_expert_used == 0); + } + + // n_head_kv is optional, default to n_head + hparams.n_head_kv = hparams.n_head; + ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + + bool rope_finetuned = false; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + hparams.n_yarn_orig_ctx = hparams.n_ctx_train; + ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false); + + // rope_freq_base (optional) + hparams.rope_freq_base_train = 10000.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, false); + + std::string rope_scaling("linear"); + ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false); + hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); + GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED); + + // rope_freq_scale (inverse of the kv) is optional + float ropescale = 0.0f; + if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) { + // try the old key name + ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false); + } + hparams.rope_freq_scale_train = ropescale == 0.0f ? 
1.0f : 1.0f/ropescale; + + // sanity check for n_rot (optional) + { + hparams.n_rot = hparams.n_embd / hparams.n_head; + + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { + if (hparams.n_rot != hparams.n_embd / hparams.n_head) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head)); + } + } + // gpt-neox n_rot = rotary_pct * (n_embd / n_head) + // gpt-j n_rot = rotary_dim + } + + // arch-specific KVs + switch (model.arch) { + case LLM_ARCH_LLAMA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 22: model.type = e_model::MODEL_1B; break; + case 26: model.type = e_model::MODEL_3B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 48: model.type = e_model::MODEL_34B; break; + case 60: model.type = e_model::MODEL_30B; break; + case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_FALCON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 60: model.type = e_model::MODEL_40B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_BAICHUAN: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_STARCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 36: model.type = e_model::MODEL_3B; break; + case 42: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_15B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_PERSIMMON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 36: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_REFACT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_1B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_BLOOM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + } break; + } + } break; + case LLM_ARCH_MPT: + { + hparams.f_clamp_kqv = 0.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_30B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_STABLELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + 
switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + + default: (void)0; + } + + model.ftype = ml.ftype; +} + +// TODO: This should probably be in llama.h +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false); +static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); + +static void llm_load_vocab( + llama_model_loader & ml, + llama_model & model) { + auto & vocab = model.vocab; + + struct gguf_context * ctx = ml.ctx_gguf; + + const auto kv = LLM_KV(model.arch); + + const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + + const float * scores = nullptr; + const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); + if (score_idx != -1) { + scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + } + + const int * toktypes = nullptr; + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); + if (toktype_idx != -1) { + toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + } + + // determine vocab type + { + std::string tokenizer_name; + + ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name); + + if (tokenizer_name == "llama") { + vocab.type = LLAMA_VOCAB_TYPE_SPM; + + // default special tokens + vocab.special_bos_id = 1; + vocab.special_eos_id = 2; + vocab.special_unk_id = 0; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + } else if (tokenizer_name == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + GGML_ASSERT(codepoints_from_utf8(word).size() > 0); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + vocab.bpe_ranks.emplace(std::make_pair(first, second), i); + } + + // default special tokens + vocab.special_bos_id = 11; + vocab.special_eos_id = 11; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + } else { + LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); + LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); + + vocab.type = LLAMA_VOCAB_TYPE_SPM; + } + } + + const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); + + vocab.id_to_token.resize(n_vocab); + + for (uint32_t i = 0; i < n_vocab; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); + GGML_ASSERT(codepoints_from_utf8(word).size() > 0); + + vocab.token_to_id[word] = i; + + auto & token_data = vocab.id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores ? 
scores[i] : 0.0f; + token_data.type = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL; + } + GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); + + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' + if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); + } else { + const std::vector ids = llama_tokenize_internal(vocab, "\u010A", false); + GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + vocab.linefeed_id = ids[0]; + } + + // special tokens + { + const std::vector> special_token_types = { + { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, + { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, + }; + for (const auto & it : special_token_types) { + const std::string & key = kv(std::get<0>(it)); + int32_t & id = std::get<1>(it); + + uint32_t new_id; + if (!ml.get_key(std::get<0>(it), new_id, false)) { + continue; + } + if (new_id >= vocab.id_to_token.size()) { + LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n", + __func__, key.c_str(), new_id, id); + } else { + id = new_id; + } + + } + + // Handle add_bos_token and add_eos_token + { + bool temp = true; + + if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) { + vocab.special_add_bos = int(temp); + } + if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) { + vocab.special_add_eos = int(temp); + } + } + } + + // build special tokens cache + { + // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type, + // and will always be correctly labeled in 'added_tokens.json' etc. + // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed + // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer + // are special tokens. + // From testing, this appears to correlate 1:1 with special tokens. + // + + // Counting special tokens and verifying in only one direction + // is sufficient to detect difference in those two sets. 
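+        // For example (hypothetical vocab entries): a marker such as "<|endoftext|>"
+        // has no split point at which both halves are themselves tokens, and its
+        // UTF-8 length is > 1, so it ends up in special_tokens_cache; an ordinary
+        // merge such as "ing" splits into "in" + "g" and is skipped.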
+ // + uint32_t special_tokens_count_by_type = 0; + uint32_t special_tokens_count_from_verification = 0; + + bool special_tokens_definition_mismatch = false; + + for (const auto & t : vocab.token_to_id) { + const auto & token = t.first; + const auto & id = t.second; + + // Count all non-normal tokens in the vocab while iterating + if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) { + special_tokens_count_by_type++; + } + + // Skip single character tokens + if (token.length() > 1) { + bool is_tokenizable = false; + + // Split token string representation in two, in all possible ways + // and check if both halves can be matched to a valid token + for (unsigned i = 1; i < token.length();) { + const auto left = token.substr(0, i); + const auto right = token.substr(i); + + // check if we didnt partition in the middle of a utf sequence + auto utf = utf8_len(left.at(left.length() - 1)); + + if (utf == 1) { + if (vocab.token_to_id.find(left) != vocab.token_to_id.end() && + vocab.token_to_id.find(right) != vocab.token_to_id.end() ) { + is_tokenizable = true; + break; + } + i++; + } else { + // skip over the rest of multibyte utf sequence + i += utf - 1; + } + } + + if (!is_tokenizable) { + // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1 + // it's faster to re-filter them here, since there are way less candidates now + + // Calculate a total "utf" length of a token string representation + size_t utf8_str_len = 0; + for (unsigned i = 0; i < token.length();) { + utf8_str_len++; + i += utf8_len(token.at(i)); + } + + // And skip the ones which are one character + if (utf8_str_len > 1) { + // At this point what we have left are special tokens only + vocab.special_tokens_cache[token] = id; + + // Count manually found special tokens + special_tokens_count_from_verification++; + + // If this manually found special token is not marked as such, flag a mismatch + if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) { + special_tokens_definition_mismatch = true; + } + } + } + } + } + + if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) { + LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n", + __func__, + special_tokens_count_from_verification, vocab.id_to_token.size(), + special_tokens_count_by_type, vocab.id_to_token.size() + ); + } else { + LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n", + __func__, + special_tokens_count_from_verification, vocab.id_to_token.size() + ); + } + } +} + +static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { + const auto & hparams = model.hparams; + const auto & vocab = model.vocab; + + const auto rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train); + + // hparams + LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? 
"SPM" : "BPE"); // TODO: fix + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); + LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); + LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); + LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); + if (ml.n_bytes < GiB) { + LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + } else { + LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + } + + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); + + // special tokens + if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } + if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } + if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } + if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } + if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } + if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } +} + +static void llm_load_tensors( + llama_model_loader & ml, + llama_model & model, + int n_gpu_layers, + int main_gpu, + const float * tensor_split, + bool use_mlock, + llama_progress_callback 
progress_callback, + void * progress_callback_user_data) { + model.t_start_us = ggml_time_us(); + + auto & ctx = model.ctx; + auto & hparams = model.hparams; + + model.n_gpu_layers = n_gpu_layers; + + size_t ctx_size; + size_t mmapped_size; + + ml.calc_sizes(ctx_size, mmapped_size); + + LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, ctx_size/1024.0/1024.0); + + // create the ggml context + { + model.buf.resize(ctx_size); + if (use_mlock) { + model.mlock_buf.init (model.buf.data); + model.mlock_buf.grow_to(model.buf.size); + } + + struct ggml_init_params params = { + /*.mem_size =*/ model.buf.size, + /*.mem_buffer =*/ model.buf.data, + /*.no_alloc =*/ ml.use_mmap, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) { + throw std::runtime_error(format("ggml_init() failed")); + } + } + + (void) main_gpu; + + enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU; + enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU; + +#ifdef GGML_USE_CUBLAS + if (ggml_cublas_loaded()) { + LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__); + ggml_cuda_set_main_device(main_gpu); + + llama_backend_offload = GGML_BACKEND_GPU; + llama_backend_offload_split = GGML_BACKEND_GPU_SPLIT; + } +#elif defined(GGML_USE_CLBLAST) + LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__); + llama_backend_offload = GGML_BACKEND_GPU; + llama_backend_offload_split = GGML_BACKEND_GPU; +#endif + + // prepare memory for the weights + size_t vram_weights = 0; + { + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + const int64_t n_layer = hparams.n_layer; + const int64_t n_vocab = hparams.n_vocab; + + const auto tn = LLM_TN(model.arch); + switch (model.arch) { + case LLM_ARCH_LLAMA: + case LLM_ARCH_REFACT: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, backend, false); + layer.bk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, backend, false); + layer.bv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, backend, false); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend, false); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false); + + if (layer.ffn_gate_inp == nullptr) { + GGML_ASSERT(hparams.n_expert == 0); + GGML_ASSERT(hparams.n_expert_used == 0); + + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + } else { + GGML_ASSERT(hparams.n_expert > 0); + GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + for (uint32_t x = 0; x < hparams.n_expert; ++x) { + layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}, backend_split); + layer.ffn_down_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}, backend_split); + layer.ffn_up_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}, backend_split); + } + } + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + + (layer.bq ? ggml_nbytes(layer.bq) : 0) + + (layer.bk ? ggml_nbytes(layer.bk) : 0) + + (layer.bv ? ggml_nbytes(layer.bv) : 0) + + (layer.bo ? 
ggml_nbytes(layer.bo) : 0) + + ggml_nbytes(layer.ffn_norm); + + if (layer.ffn_gate_inp == nullptr) { + vram_weights += + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); + } else { + vram_weights += ggml_nbytes(layer.ffn_gate_inp); + for (uint32_t x = 0; x < hparams.n_expert; ++x) { + vram_weights += + ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]); + } + } + } + } + } break; + case LLM_ARCH_BAICHUAN: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); + } + } + } break; + case LLM_ARCH_FALCON: + { + // TODO: CPU-only for now + + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + 
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) { + layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend); + layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(layer.attn_norm_2); + vram_weights += ggml_nbytes(layer.attn_norm_2_b); + } + } + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) + + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); + } + } + } break; + case LLM_ARCH_STARCODER: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.pos_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff 
= hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend); + + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); + + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + + ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b) + + ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b); + } + } + } break; + case LLM_ARCH_PERSIMMON: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + const int i_gpu_start = n_layer - n_gpu_layers; + model.layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; + const ggml_backend_type backend_split = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : llama_backend_offload_split; + auto & layer = model.layers[i]; + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + layer.attn_q_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}, backend); + layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend); + layer.attn_k_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}, backend); + layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend); + } + } break; + case LLM_ARCH_BLOOM: + { + // TODO: CPU-only for now + + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU); + model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend); + + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend); + + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + + ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + + ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_up_b) + + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down_b); + } + } + } break; + case LLM_ARCH_MPT: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + + ggml_nbytes(layer.wqkv) + + ggml_nbytes(layer.wo) + + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.ffn_down) + + ggml_nbytes(layer.ffn_up); + } + } + } break; + case LLM_ARCH_STABLELM: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + /* + llama_model_loader: - tensor 4: blk.0.attn_output.weight f16 [ 2560, 2560, 1, 1 ] + */ + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); + } + } + } break; + case LLM_ARCH_QWEN: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff / 2; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd * 3}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd * 3}, backend); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) + + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); + } + } + } break; + + default: + throw std::runtime_error("unknown architecture"); + } + } + + ml.done_getting_tensors(); + + // print memory requirements + { + // this is the total memory required to run the inference + size_t mem_required = + ctx_size + + mmapped_size - vram_weights; // weights in VRAM not in memory + + LLAMA_LOG_INFO("%s: mem required = %7.2f MiB\n", __func__, mem_required / 1024.0 / 1024.0); + +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); + if (n_gpu_layers > (int) hparams.n_layer) { + LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__); + } + +#ifdef GGML_USE_CUBLAS + const int max_backend_supported_layers = hparams.n_layer + 1; + const int max_offloadable_layers = hparams.n_layer + 1; +#elif GGML_USE_CLBLAST + const int max_backend_supported_layers = hparams.n_layer + 1; + const int max_offloadable_layers = hparams.n_layer + 1; +#endif // GGML_USE_CUBLAS + + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + LLAMA_LOG_INFO("%s: VRAM used: %.2f MiB\n", __func__, vram_weights / 1024.0 / 1024.0); +#else + (void) n_gpu_layers; +#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + } + + // populate `tensors_by_name` + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i)); + model.tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + + (void) tensor_split; +#ifdef GGML_USE_CUBLAS + { + ggml_cuda_set_tensor_split(tensor_split); + } +#endif + + ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? 
&model.mlock_mmap : NULL); + + if (progress_callback) { + progress_callback(1.0f, progress_callback_user_data); + } + + model.mapping = std::move(ml.mapping); + + // loading time will be recalculate after the first eval, so + // we take page faults deferred by mmap() into consideration + model.t_load_us = ggml_time_us() - model.t_start_us; +} + +static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) { + try { + llama_model_loader ml(fname, params.use_mmap, params.kv_overrides); + + model.hparams.vocab_only = params.vocab_only; + + llm_load_arch (ml, model); + llm_load_hparams(ml, model); + llm_load_vocab (ml, model); + + llm_load_print_meta(ml, model); + + if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { + throw std::runtime_error("vocab size mismatch"); + } + + if (params.vocab_only) { + LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); + return true; + } + + llm_load_tensors( + ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock, + params.progress_callback, params.progress_callback_user_data + ); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); + return false; + } + + return true; +} + +// +// llm_build +// + +using llm_build_cb = std::function; + +enum llm_rope_type { + LLM_ROPE, + LLM_ROPE_NEOX, + LLM_ROPE_GLM, +}; + +enum llm_ffn_op_type { + LLM_FFN_SILU, + LLM_FFN_GELU, + LLM_FFN_RELU, + LLM_FFN_RELU_SQR, +}; + +enum llm_ffn_gate_type { + LLM_FFN_SEQ, + LLM_FFN_PAR, // ffn_gate is parallel to ffn_up +}; + +enum llm_norm_type { + LLM_NORM, + LLM_NORM_RMS, +}; + +static struct ggml_tensor * llm_build_inp_embd( + struct ggml_context * ctx, + const llama_hparams & hparams, + const llama_batch & batch, + struct ggml_tensor * tok_embd, + const llm_build_cb & cb) { + const int64_t n_embd = hparams.n_embd; + + struct ggml_tensor * inpL; + + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens); + cb(inp_tokens, "inp_tokens", -1); + + inpL = ggml_get_rows(ctx, tok_embd, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); + } + + return inpL; +} + +// Persimmon: n_rot = n_embd_head/2 +// Other: n_rot = n_embd_head +static void llm_build_k_shift( + struct ggml_context * ctx, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache & kv, + struct ggml_cgraph * graph, + llm_rope_type type, + int64_t n_ctx, + int n_rot, + float freq_base, + float freq_scale, + const llm_build_cb & cb) { + const int64_t n_layer = hparams.n_layer; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + const int64_t n_embd_head = hparams.n_embd_head(); + const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx; + const float ext_factor = cparams.yarn_ext_factor; + const float attn_factor = cparams.yarn_attn_factor; + const float beta_fast = cparams.yarn_beta_fast; + const float beta_slow = cparams.yarn_beta_slow; + + GGML_ASSERT(n_embd_head % n_rot == 0); + + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx); + cb(K_shift, "K_shift", -1); + + int rope_type = 0; + + switch (type) { + case LLM_ROPE: rope_type = 0; break; + case LLM_ROPE_NEOX: rope_type = 2; break; + case LLM_ROPE_GLM: rope_type = 4; break; + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = 
+ // we rotate only the first n_rot dimensions + ggml_rope_custom_inplace(ctx, + ggml_view_3d(ctx, kv.k_l[il], + n_embd_head, n_head_kv, n_ctx, + ggml_row_size(kv.k_l[il]->type, n_embd_head), + ggml_row_size(kv.k_l[il]->type, n_embd_gqa), + 0), + K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(tmp, "K_shifted", il); + ggml_build_forward_expand(graph, tmp); + } +} + +static void llm_build_kv_store( + struct ggml_context * ctx, + const llama_hparams & hparams, + const llama_kv_cache & kv, + struct ggml_cgraph * graph, + struct ggml_tensor * k_cur, + struct ggml_tensor * v_cur, + int64_t n_ctx, + int32_t n_tokens, + int32_t kv_head, + const llm_build_cb & cb, + int64_t il) { + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + // compute the transposed [n_tokens, n_embd] V matrix + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens)); + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed + cb(v_cur_t, "v_cur_t", il); + + struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa, + (ggml_row_size(kv.k_l[il]->type, n_embd_gqa))*kv_head); + cb(k_cache_view, "k_cache_view", il); + + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv.v_l[il]), + (kv_head)*ggml_element_size(kv.v_l[il])); + cb(v_cache_view, "v_cache_view", il); + + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); +} + +static struct ggml_tensor * llm_build_norm( + struct ggml_context * ctx, + struct ggml_tensor * cur, + const llama_hparams & hparams, + struct ggml_tensor * mw, + struct ggml_tensor * mb, + llm_norm_type type, + const llm_build_cb & cb, + int il) { + switch (type) { + case LLM_NORM: cur = ggml_norm (ctx, cur, hparams.f_norm_eps); break; + case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break; + } + + if (mw || mb) { + cb(cur, "norm", il); + } + + if (mw) { + cur = ggml_mul(ctx, cur, mw); + if (mb) { + cb(cur, "norm_w", il); + } + } + + if (mb) { + cur = ggml_add(ctx, cur, mb); + } + + return cur; +} + +static struct ggml_tensor * llm_build_ffn( + struct ggml_context * ctx, + struct ggml_tensor * cur, + struct ggml_tensor * up, + struct ggml_tensor * up_b, + struct ggml_tensor * gate, + struct ggml_tensor * gate_b, + struct ggml_tensor * down, + struct ggml_tensor * down_b, + llm_ffn_op_type type_op, + llm_ffn_gate_type type_gate, + const llm_build_cb & cb, + int il) { + struct ggml_tensor * tmp = ggml_mul_mat(ctx, up, cur); + cb(tmp, "ffn_up", il); + + if (up_b) { + tmp = ggml_add(ctx, tmp, up_b); + cb(tmp, "ffn_up_b", il); + } + + if (gate) { + switch (type_gate) { + case LLM_FFN_SEQ: + { + cur = ggml_mul_mat(ctx, gate, tmp); + cb(cur, "ffn_gate", il); + } break; + case LLM_FFN_PAR: + { + cur = ggml_mul_mat(ctx, gate, cur); + cb(cur, "ffn_gate", il); + } break; + } + + if (gate_b) { + cur = ggml_add(ctx, cur, gate_b); + cb(cur, "ffn_gate_b", il); + } + } else { + cur = tmp; + } + + switch (type_op) { + case LLM_FFN_SILU: + { + cur = ggml_silu(ctx, cur); + cb(cur, "ffn_silu", il); + } break; + case LLM_FFN_GELU: + { + cur = ggml_gelu(ctx, cur); + cb(cur, "ffn_gelu", il); + } break; + case LLM_FFN_RELU: + { + cur = ggml_relu(ctx, cur); + cb(cur, "ffn_relu", il); + } 
break; + case LLM_FFN_RELU_SQR: + { + cur = ggml_relu(ctx, cur); + cb(cur, "ffn_relu", il); + + cur = ggml_sqr(ctx, cur); + cb(cur, "ffn_sqr(relu)", il); + } break; + } + + if (type_gate == LLM_FFN_PAR) { + cur = ggml_mul(ctx, cur, tmp); + cb(cur, "ffn_gate_par", il); + } + + cur = ggml_mul_mat(ctx, down, cur); + if (down_b) { + cb(cur, "ffn_down", il); + } + + if (down_b) { + cur = ggml_add(ctx, cur, down_b); + } + + return cur; +} + +// if max_alibi_bias > 0 then apply ALiBi +static struct ggml_tensor * llm_build_kqv( + struct ggml_context * ctx, + const llama_hparams & hparams, + const llama_kv_cache & kv, + struct ggml_tensor * wo, + struct ggml_tensor * wo_b, + struct ggml_tensor * q_cur, + struct ggml_tensor * kq_scale, + struct ggml_tensor * kq_mask, + int64_t n_ctx, + int32_t n_tokens, + int32_t n_kv, + float max_alibi_bias, + const llm_build_cb & cb, + int il) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); + cb(q, "q", il); + + struct ggml_tensor * k = + ggml_view_3d(ctx, kv.k_l[il], + n_embd_head, n_kv, n_head_kv, + ggml_row_size(kv.k_l[il]->type, n_embd_gqa), + ggml_row_size(kv.k_l[il]->type, n_embd_head), + 0); + cb(k, "k", il); + + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); + } + + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); + + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); + cb(kq, "kq_soft_max_ext", il); + } + + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head, + 0); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens); + cb(cur, "kqv_merged_cont", il); + + cur = ggml_mul_mat(ctx, wo, cur); + if (wo_b) { + cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx, cur, wo_b); + } + + return cur; +} + +struct llm_build_context { + const llama_model & model; + const llama_hparams & hparams; + const llama_cparams & cparams; + const llama_batch & batch; + const llama_kv_cache & kv_self; + + const int64_t n_embd; + const int64_t n_layer; + const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) + const int64_t n_head; + const int64_t n_head_kv; + const int64_t n_embd_head; + const int64_t n_embd_gqa; + const int64_t n_expert; + const int64_t n_expert_used; + + const float freq_base; + const float freq_scale; + const float ext_factor; + const float attn_factor; + const float beta_fast; + const float beta_slow; + const float norm_eps; + const float 
norm_rms_eps; + + const int32_t n_tokens; + const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx) + const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t n_orig_ctx; + + const bool do_rope_shift; + + const llm_build_cb & cb; + + llama_buffer & buf_compute; + + struct ggml_context * ctx0 = nullptr; + + // TODO: consider making the entire interface noexcept + llm_build_context( + llama_context & lctx, + const llama_batch & batch, + const llm_build_cb & cb, + bool worst_case) : + model (lctx.model), + hparams (model.hparams), + cparams (lctx.cparams), + batch (batch), + kv_self (lctx.kv_self), + n_embd (hparams.n_embd), + n_layer (hparams.n_layer), + n_ctx (cparams.n_ctx), + n_head (hparams.n_head), + n_head_kv (hparams.n_head_kv), + n_embd_head (hparams.n_embd_head()), + n_embd_gqa (hparams.n_embd_gqa()), + n_expert (hparams.n_expert), + n_expert_used (hparams.n_expert_used), + freq_base (cparams.rope_freq_base), + freq_scale (cparams.rope_freq_scale), + ext_factor (cparams.yarn_ext_factor), + attn_factor (cparams.yarn_attn_factor), + beta_fast (cparams.yarn_beta_fast), + beta_slow (cparams.yarn_beta_slow), + norm_eps (hparams.f_norm_eps), + norm_rms_eps (hparams.f_norm_rms_eps), + n_tokens (batch.n_tokens), + n_kv (worst_case ? n_ctx : kv_self.n), + kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), + n_orig_ctx (cparams.n_yarn_orig_ctx), + do_rope_shift (worst_case || kv_self.has_shift), + cb (cb), + buf_compute (lctx.buf_compute) { + GGML_ASSERT(!!kv_self.ctx); + + // all initializations should be done in init() + } + + void init() { + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ true, + }; + + ctx0 = ggml_init(params); + } + + void free() { + if (ctx0) { + ggml_free(ctx0); + ctx0 = nullptr; + } + } + + struct ggml_cgraph * build_llama() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", 
il); + } + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] + cb(probs, "ffn_moe_probs", il); + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + + ggml_tensor * weights = ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); + cb(weights, "ffn_moe_weights", il); + + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] + + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] + cb(weights, "ffn_moe_weights_norm", il); + + // compute expert outputs + ggml_tensor * moe_out = nullptr; + + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert; + + ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); + cb(cur_up, "ffn_moe_up", il); + + ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur); + cb(cur_gate, "ffn_moe_gate", il); + + cur_gate = ggml_silu(ctx0, cur_gate); + cb(cur_gate, "ffn_moe_silu", il); + + cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd] + cb(cur_expert, "ffn_moe_gate_par", il); + + cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd] + cb(cur_expert, "ffn_moe_down", il); + + cur_expert = ggml_mul(ctx0, cur_expert, + ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); + 
cb(cur_expert, "ffn_moe_weighted", il); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx0, moe_out, cur_expert); + cb(moe_out, "ffn_moe_out", il); + } + } + + cur = moe_out; + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_baichuan() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + switch (model.type) { + case MODEL_7B: + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + break; + case MODEL_13B: + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens); + break; + default: + GGML_ASSERT(false); + } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + // apply ALiBi for 13B model + const float max_alibi_bias = model.type == MODEL_13B ? 
8.0f : -1.0f; + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_falcon() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * attn_norm; + + attn_norm = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(attn_norm, "attn_norm", il); + + // self-attention + { + if (model.layers[il].attn_norm_2) { + // Falcon-40B + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm_2, + model.layers[il].attn_norm_2_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm_2", il); + } else { + cur = attn_norm; + } + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + // using mode = 2 for neox mode + Qcur = ggml_rope_custom( + ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, + freq_base, 
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = cur; + + // feed forward + { + cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result + model.layers[il].ffn_up, NULL, + NULL, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_starcoder() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct ggml_tensor * cur; + struct ggml_tensor * pos; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + // add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", 
il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_persimmon() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_rot = n_embd_head / 2; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "imp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * residual = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + // split qkv + GGML_ASSERT(n_head_kv == n_head); + + struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens); + cb(tmpqkv, "tmpqkv", il); + + struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); + cb(tmpqkv_perm, "tmpqkv", il); + + struct ggml_tensor * tmpq = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + 0 + ); + cb(tmpq, "tmpq", il); + + struct ggml_tensor * tmpk = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens + ); + cb(tmpk, "tmpk", il); + + // Q/K Layernorm + tmpq = llm_build_norm(ctx0, tmpq, hparams, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, cb, il); + cb(tmpq, "tmpq", il); + + tmpk = llm_build_norm(ctx0, tmpk, hparams, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, cb, il); + cb(tmpk, "tmpk", il); + + // RoPE the first n_rot of q/k, pass the other half, and concat. 
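+ // note: n_rot = n_embd_head/2 here (set at the top of build_persimmon), so each Q/K head is split in two:
+ // the first n_rot dims (qrot/krot below) get rotary embeddings applied, the remaining dims (qpass/kpass)
+ // pass through unchanged, and the two halves are concatenated back together afterwards.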
+ struct ggml_tensor * qrot = ggml_view_3d( + ctx0, tmpq, n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + 0 + ); + cb(qrot, "qrot", il); + + struct ggml_tensor * krot = ggml_view_3d( + ctx0, tmpk, n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head, + 0 + ); + cb(krot, "krot", il); + + // get the second half of tmpq, e.g tmpq[n_rot:, :, :] + struct ggml_tensor * qpass = ggml_view_3d( + ctx0, tmpq, n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + ggml_element_size(tmpq) * n_rot + ); + cb(qpass, "qpass", il); + + struct ggml_tensor * kpass = ggml_view_3d( + ctx0, tmpk, n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head, + ggml_element_size(tmpk) * n_rot + ); + cb(kpass, "kpass", il); + + struct ggml_tensor * qrotated = ggml_rope_custom( + ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(qrotated, "qrotated", il); + + struct ggml_tensor * krotated = ggml_rope_custom( + ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(krotated, "krotated", il); + + // ggml currently only supports concatenation on dim=2 + // so we need to permute qrot, qpass, concat, then permute back. + qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3)); + cb(qrotated, "qrotated", il); + + krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3)); + cb(krotated, "krotated", il); + + qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); + cb(qpass, "qpass", il); + + kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3)); + cb(kpass, "kpass", il); + + struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3)); + cb(Q, "Q", il); + + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3)); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2 + ); + cb(Vcur, "Vcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + // TODO: not tested, could be broken + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, 
hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_refact() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + cb(Kcur, "Kcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_bloom() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + inpL = llm_build_norm(ctx0, inpL, hparams, + model.tok_norm, + model.tok_norm_b, + LLM_NORM, cb, -1); + cb(inpL, "inp_norm", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + 
model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il); + cb(cur, "kqv_out", il); + } + + // Add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_mpt() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * attn_norm; + + attn_norm = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + NULL, + LLM_NORM, cb, il); + cb(attn_norm, "attn_norm", il); + + // self-attention + { + cur = attn_norm; + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + if (hparams.f_clamp_kqv > 0.0f) { + cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(cur, "wqkv_clamped", il); + } + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, 
"Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il); + cb(cur, "kqv_out", il); + } + + // Add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // feed forward + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + NULL, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + NULL, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, + NULL, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_stablelm() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, 
cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_qwen() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + // using mode = 2 for neox mode + Qcur = ggml_rope_custom( + ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", 
il); + + // feed-forward forward + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } +}; + +// +// tensor offloading helpers +// +// TODO: will be removed with backend v2 + +enum llm_offload_func_e { + OFFLOAD_FUNC_NOP, + OFFLOAD_FUNC, + OFFLOAD_FUNC_FRC, // force offload + OFFLOAD_FUNC_KQV, + OFFLOAD_FUNC_NR, + OFFLOAD_FUNC_EMB, + OFFLOAD_FUNC_OUT, +}; + +// TODO: will be removed with backend v2 +struct llm_offload_trie { + struct node { + ~node() { + for (int i = 0; i < 256; ++i) { + if (children[i]) { + delete children[i]; + } + } + } + + node * children[256] = { nullptr }; + llm_offload_func_e func = OFFLOAD_FUNC_NOP; + }; + + llm_offload_trie() { + root = new node; + } + + llm_offload_trie(const std::unordered_map & map) { + root = new node; + + for (const auto & kv : map) { + add(kv.first, kv.second); + } + } + + ~llm_offload_trie() { + delete root; + } + + void add(const char * name, llm_offload_func_e func) { + node * cur = root; + + for (int i = 0; ; ++i) { + const uint8_t c = name[i]; + + if (!c) { + break; + } + + if (!cur->children[c]) { + cur->children[c] = new node; + } + + cur = cur->children[c]; + } + + cur->func = func; + } + + llm_offload_func_e find(const char * name) const { + const node * cur = root; + + for (int i = 0; ; ++i) { + const uint8_t c = name[i]; + + if (!c) { + break; + } + + if (!cur->children[c]) { + return OFFLOAD_FUNC_NOP; + } + + cur = cur->children[c]; + } + + return cur->func; + } + + node * root = nullptr; +}; + +// TODO: will be removed with backend v2 +static const std::unordered_map k_offload_map = { + //{ "inp_tokens", OFFLOAD_FUNC_NR }, // TODO: missing K-quants get_rows kernel + //{ "inp_embd", OFFLOAD_FUNC_NR }, // TODO: missing K-quants get_rows kernel + { "pos_embd", OFFLOAD_FUNC_NR }, + + { "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. 
rope) + { "KQ_scale", OFFLOAD_FUNC_FRC }, + { "KQ_mask", OFFLOAD_FUNC_FRC }, + { "K_shift", OFFLOAD_FUNC_FRC }, + + { "K_shifted", OFFLOAD_FUNC }, + + { "inp_norm", OFFLOAD_FUNC_NR }, + { "inp_norm_w", OFFLOAD_FUNC_NR }, + { "inp_norm_wb", OFFLOAD_FUNC_NR }, + + { "norm", OFFLOAD_FUNC }, + { "norm_w", OFFLOAD_FUNC }, + { "norm_wb", OFFLOAD_FUNC }, + + { "attn_norm", OFFLOAD_FUNC }, + { "attn_norm_2", OFFLOAD_FUNC }, + + { "wqkv", OFFLOAD_FUNC_KQV }, + { "bqkv", OFFLOAD_FUNC_KQV }, + { "wqkv_clamped", OFFLOAD_FUNC_KQV }, + + { "tmpk", OFFLOAD_FUNC_KQV }, + { "tmpq", OFFLOAD_FUNC_KQV }, + { "tmpv", OFFLOAD_FUNC_KQV }, + { "Kcur", OFFLOAD_FUNC_KQV }, + { "Qcur", OFFLOAD_FUNC_KQV }, + { "Vcur", OFFLOAD_FUNC_KQV }, + + { "krot", OFFLOAD_FUNC_KQV }, + { "qrot", OFFLOAD_FUNC_KQV }, + { "kpass", OFFLOAD_FUNC_KQV }, + { "qpass", OFFLOAD_FUNC_KQV }, + { "krotated", OFFLOAD_FUNC_KQV }, + { "qrotated", OFFLOAD_FUNC_KQV }, + + { "q", OFFLOAD_FUNC_KQV }, + { "k", OFFLOAD_FUNC_KQV }, + { "kq", OFFLOAD_FUNC_KQV }, + { "kq_scaled", OFFLOAD_FUNC_KQV }, + { "kq_scaled_alibi", OFFLOAD_FUNC_KQV }, + { "kq_masked", OFFLOAD_FUNC_KQV }, + { "kq_soft_max", OFFLOAD_FUNC_KQV }, + { "kq_soft_max_ext", OFFLOAD_FUNC_KQV }, + { "v", OFFLOAD_FUNC_KQV }, + { "kqv", OFFLOAD_FUNC_KQV }, + { "kqv_merged", OFFLOAD_FUNC_KQV }, + { "kqv_merged_cont", OFFLOAD_FUNC_KQV }, + { "kqv_wo", OFFLOAD_FUNC_KQV }, + { "kqv_out", OFFLOAD_FUNC_KQV }, + + { "ffn_inp", OFFLOAD_FUNC }, + { "ffn_norm", OFFLOAD_FUNC }, + + { "ffn_up", OFFLOAD_FUNC }, + { "ffn_up_b", OFFLOAD_FUNC }, + { "ffn_gate", OFFLOAD_FUNC }, + { "ffn_gate_b", OFFLOAD_FUNC }, + { "ffn_gate_par", OFFLOAD_FUNC }, + { "ffn_down", OFFLOAD_FUNC }, + { "ffn_down_b", OFFLOAD_FUNC }, + { "ffn_out", OFFLOAD_FUNC }, + + { "ffn_silu", OFFLOAD_FUNC }, + { "ffn_gelu", OFFLOAD_FUNC }, + { "ffn_relu", OFFLOAD_FUNC }, + { "ffn_sqr(relu)", OFFLOAD_FUNC }, + + { "ffn_moe_logits", OFFLOAD_FUNC }, + { "ffn_moe_probs", OFFLOAD_FUNC }, + { "ffn_moe_argsort", OFFLOAD_FUNC }, + { "ffn_moe_weights", OFFLOAD_FUNC }, + { "ffn_moe_weights_sum", OFFLOAD_FUNC }, + { "ffn_moe_weights_norm", OFFLOAD_FUNC }, + { "ffn_moe_weighted", OFFLOAD_FUNC }, + { "ffn_moe_up", OFFLOAD_FUNC }, + { "ffn_moe_gate", OFFLOAD_FUNC }, + { "ffn_moe_silu", OFFLOAD_FUNC }, + { "ffn_moe_gate_par", OFFLOAD_FUNC }, + { "ffn_moe_down", OFFLOAD_FUNC }, + { "ffn_moe_out", OFFLOAD_FUNC }, + + { "l_out", OFFLOAD_FUNC }, + + { "result_norm", OFFLOAD_FUNC_EMB }, + { "result_output", OFFLOAD_FUNC_OUT }, +}; + +static llm_offload_trie k_offload_func_trie(k_offload_map); + +static struct ggml_cgraph * llama_build_graph( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + + // check if we should build the worst-case graph (for memory measurement) + const bool worst_case = ggml_allocr_is_measure(lctx.alloc); + + // keep track of the input that has already been allocated + bool alloc_inp_tokens = false; + bool alloc_inp_embd = false; + bool alloc_inp_pos = false; + bool alloc_inp_KQ_scale = false; + bool alloc_inp_KQ_mask = false; + bool alloc_inp_K_shift = false; + +#ifdef GGML_USE_CUBLAS + const bool do_offload = true; +#else + const bool do_offload = true; // TODO: set to false after finishing refactoring +#endif + + int n_non_view = 0; // number of non-view tensors that have been processed by the callback + + // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) 
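+ // note: besides naming, this callback also fills the input tensors (inp_tokens, inp_embd,
+ // inp_pos, KQ_scale, KQ_mask, K_shift) the first time each one is seen, and looks up every
+ // non-view tensor's name in k_offload_func_trie to pick its offload function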
+ // TODO: will be removed with backend v2 + llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) { + if (il >= 0) { + ggml_format_name(cur, "%s-%d", name, il); + } else { + ggml_set_name(cur, name); + } + + // + // allocate input tensors and set input data + // + // TODO: will be removed with backend v2 + + if (!alloc_inp_tokens && strcmp(name, "inp_tokens") == 0) { + ggml_allocr_alloc(lctx.alloc, cur); + + if (!ggml_allocr_is_measure(lctx.alloc) && batch.token) { + const int64_t n_tokens = cur->ne[0]; + + memcpy(cur->data, batch.token, n_tokens*ggml_element_size(cur)); + } + + alloc_inp_tokens = true; + } + + if (!alloc_inp_embd && strcmp(name, "inp_embd") == 0) { + ggml_allocr_alloc(lctx.alloc, cur); + + if (!ggml_allocr_is_measure(lctx.alloc) && batch.embd) { + const int64_t n_embd = cur->ne[0]; + const int64_t n_tokens = cur->ne[1]; + + memcpy(cur->data, batch.embd, n_tokens*n_embd*ggml_element_size(cur)); + } + + alloc_inp_embd = true; + } + + if (!alloc_inp_pos && strcmp(name, "inp_pos") == 0) { + ggml_allocr_alloc(lctx.alloc, cur); + + if (!ggml_allocr_is_measure(lctx.alloc) && batch.pos) { + const int64_t n_tokens = cur->ne[0]; + + int32_t * data = (int32_t *) cur->data; + + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + + alloc_inp_pos = true; + } + + if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) { + ggml_allocr_alloc(lctx.alloc, cur); + + if (!ggml_allocr_is_measure(lctx.alloc)) { + const int64_t n_embd_head = model.hparams.n_embd_head(); + ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head))); + } + + alloc_inp_KQ_scale = true; + } + + if (!alloc_inp_KQ_mask && strcmp(name, "KQ_mask") == 0) { + ggml_allocr_alloc(lctx.alloc, cur); + + if (!ggml_allocr_is_measure(lctx.alloc)) { + const int64_t n_kv = cur->ne[0]; + const int64_t n_tokens = cur->ne[1]; + + float * data = (float *) cur->data; + memset(data, 0, ggml_nbytes(cur)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = 0; i < n_kv; ++i) { + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + alloc_inp_KQ_mask = true; + } + + if (!alloc_inp_K_shift && strcmp(name, "K_shift") == 0) { + ggml_allocr_alloc(lctx.alloc, cur); + + if (!ggml_allocr_is_measure(lctx.alloc)) { + const int64_t n_ctx = cur->ne[0]; + + int32_t * data = (int32_t *) cur->data; + + for (int i = 0; i < n_ctx; ++i) { + data[i] = lctx.kv_self.cells[i].delta; + } + } + + alloc_inp_K_shift = true; + } + + // view tensors are not processed further + if (cur->view_src != nullptr) { + return; + } + + if (cur->op != GGML_OP_NONE) { + n_non_view++; + } + + // + // offload layers + // + // TODO: will be removed with backend v2 + +//#define LLAMA_OFFLOAD_DEBUG + + if (!do_offload) { + return; + } + + const int n_layer = model.hparams.n_layer; + + const int n_gpu_layers = model.n_gpu_layers; + const int i_gpu_start = n_layer - n_gpu_layers; + + // should we offload the final norm? 
yes if we are not computing embeddings + const bool offload_emb = lctx.embedding.empty(); + + static const std::unordered_map> k_offload_func_name = { + { OFFLOAD_FUNC_NOP, "CPU" }, + { OFFLOAD_FUNC_OUT, "CPU" }, +#ifdef GGML_USE_CUBLAS + { OFFLOAD_FUNC, "GPU (CUDA)" }, + { OFFLOAD_FUNC_FRC, "GPU (CUDA) FRC" }, + { OFFLOAD_FUNC_KQV, "GPU (CUDA) KQV" }, + { OFFLOAD_FUNC_NR, "GPU (CUDA) NR" }, + { OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" }, +#else + { OFFLOAD_FUNC, "CPU" }, + { OFFLOAD_FUNC_FRC, "CPU" }, + { OFFLOAD_FUNC_KQV, "CPU" }, + { OFFLOAD_FUNC_NR, "CPU" }, + { OFFLOAD_FUNC_EMB, "CPU" }, +#endif // GGML_USE_CUBLAS + }; + + // check the global map for what offload function to use for this tensor + llm_offload_func_e func_e = k_offload_func_trie.find(name); + + if (func_e == OFFLOAD_FUNC_NOP) { +#ifdef LLAMA_OFFLOAD_DEBUG + // if a tensor hasn't been offloaded, we warn the user + if (worst_case) { + LLAMA_LOG_WARN("%s: %32s: not offloaded (ref: %s)\n", __func__, + cur->name, "https://github.com/ggerganov/llama.cpp/pull/3837"); + } +#endif + + return; + } + + // count the number of layers and respect the provided n_gpu_layers + switch (func_e) { + case OFFLOAD_FUNC_NOP: + case OFFLOAD_FUNC_OUT: + break; + case OFFLOAD_FUNC: + if (n_gpu_layers < n_layer) { + if (il < i_gpu_start) { + func_e = OFFLOAD_FUNC_NOP; + } + } + break; + case OFFLOAD_FUNC_FRC: + if (!lctx.cparams.offload_kqv) { + func_e = OFFLOAD_FUNC_NOP; + } break; + case OFFLOAD_FUNC_KQV: + if (!lctx.cparams.offload_kqv) { + func_e = OFFLOAD_FUNC_NOP; + } else { + if (n_gpu_layers < n_layer) { + if (il < i_gpu_start) { + func_e = OFFLOAD_FUNC_NOP; + } + } + } + break; + case OFFLOAD_FUNC_NR: + if (n_gpu_layers <= n_layer + 0) { + func_e = OFFLOAD_FUNC_NOP; + } + break; + case OFFLOAD_FUNC_EMB: + if (!offload_emb || n_gpu_layers < n_layer) { + func_e = OFFLOAD_FUNC_NOP; + } + break; + default: GGML_ASSERT(false); + } + + offload_func_t func = ggml_offload_nop; + + // this is needed for compatibility with Metal for example +#ifdef GGML_USE_CUBLAS + static offload_func_t ggml_offload_gpu = ggml_cuda_assign_buffers_no_alloc; +#else + static offload_func_t ggml_offload_gpu = ggml_offload_nop; +#endif + + switch (func_e) { + case OFFLOAD_FUNC_NOP: + case OFFLOAD_FUNC_OUT: func = ggml_offload_nop; break; + case OFFLOAD_FUNC: + case OFFLOAD_FUNC_KQV: + case OFFLOAD_FUNC_FRC: + case OFFLOAD_FUNC_NR: + case OFFLOAD_FUNC_EMB: func = ggml_offload_gpu; break; + default: GGML_ASSERT(false); + } + + // apply offload function to the tensor + func(cur); + +#ifdef LLAMA_OFFLOAD_DEBUG + if (worst_case) { + LLAMA_LOG_INFO("%s: %32s: %s\n", __func__, cur->name, k_offload_func_name.at(func_e).c_str()); + } +#endif + }; + + struct ggml_cgraph * result = NULL; + + struct llm_build_context llm(lctx, batch, cb, worst_case); + + llm.init(); + + switch (model.arch) { + case LLM_ARCH_LLAMA: + { + result = llm.build_llama(); + } break; + case LLM_ARCH_BAICHUAN: + { + result = llm.build_baichuan(); + } break; + case LLM_ARCH_FALCON: + { + result = llm.build_falcon(); + } break; + case LLM_ARCH_STARCODER: + { + result = llm.build_starcoder(); + } break; + case LLM_ARCH_PERSIMMON: + { + result = llm.build_persimmon(); + } break; + case LLM_ARCH_REFACT: + { + result = llm.build_refact(); + } break; + case LLM_ARCH_BLOOM: + { + result = llm.build_bloom(); + } break; + case LLM_ARCH_MPT: + { + result = llm.build_mpt(); + } break; + case LLM_ARCH_STABLELM: + { + result = llm.build_stablelm(); + } break; + case LLM_ARCH_QWEN: + { + result = llm.build_qwen(); + } break; + 
default: + GGML_ASSERT(false); + } + + llm.free(); + + if (worst_case) { + int n_non_view_total = 0; + + for (int i = 0; i < result->n_nodes; ++i) { + if (result->nodes[i]->view_src == nullptr) { + n_non_view_total++; + } + } + + LLAMA_LOG_INFO("%s: non-view tensors processed: %d/%d\n", __func__, n_non_view, n_non_view_total); + + if (n_non_view != n_non_view_total) { + LLAMA_LOG_WARN("%s: ****************************************************************\n", __func__); + LLAMA_LOG_WARN("%s: not all non-view tensors have been processed with a callback\n", __func__); + LLAMA_LOG_WARN("%s: this can indicate an inefficiency in the graph implementation\n", __func__); + LLAMA_LOG_WARN("%s: build with LLAMA_OFFLOAD_DEBUG for more info\n", __func__); + LLAMA_LOG_WARN("%s: ref: https://github.com/ggerganov/llama.cpp/pull/3837\n", __func__); + LLAMA_LOG_WARN("%s: ****************************************************************\n", __func__); + } + } + + return result; +} + +// decode a batch of tokens by evaluating the transformer +// +// - lctx: llama context +// - batch: batch to evaluate +// +// return 0 on success +// return positive int on warning +// return negative int on error +// +static int llama_decode_internal( + llama_context & lctx, + llama_batch batch) { + const uint32_t n_tokens = batch.n_tokens; + + if (n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); + return -1; + } + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto n_batch = cparams.n_batch; + + GGML_ASSERT(n_tokens <= n_batch); + + int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + const int64_t t_start_us = ggml_time_us(); + +#ifdef GGML_USE_MPI + // TODO: needs fix after #3228 + GGML_ASSERT(false && "not implemented"); + //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); +#endif + + GGML_ASSERT(n_threads > 0); + + auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_vocab = hparams.n_vocab; + + // helpers for smoother batch API transition + // after deprecating the llama_eval calls, these will be removed + std::vector pos; + + std::vector n_seq_id; + std::vector seq_id_arr; + std::vector> seq_id; + + if (batch.pos == nullptr) { + pos.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + pos[i] = batch.all_pos_0 + i*batch.all_pos_1; + } + + batch.pos = pos.data(); + } + + if (batch.seq_id == nullptr) { + n_seq_id.resize(n_tokens); + seq_id.resize(n_tokens); + seq_id_arr.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + n_seq_id[i] = 1; + seq_id[i].resize(1); + seq_id[i][0] = batch.all_seq_id; + seq_id_arr[i] = seq_id[i].data(); + } + + batch.n_seq_id = n_seq_id.data(); + batch.seq_id = seq_id_arr.data(); + } + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (kv_self.head > kv_self.used + 2*n_tokens) { + kv_self.head = 0; + } + + if (!llama_kv_cache_find_slot(kv_self, batch)) { + return 1; + } + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, 
GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + //kv_self.n = llama_kv_cache_cell_max(kv_self); + + //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); + + ggml_allocr_reset(lctx.alloc); + + ggml_cgraph * gf = llama_build_graph(lctx, batch); + + ggml_allocr_alloc_graph(lctx.alloc, gf); + + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + + +#ifdef GGML_USE_CUBLAS + for (int i = 0; i < gf->n_leafs; i++) { + ggml_tensor * node = gf->leafs[i]; + if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { + ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + ggml_cuda_copy_to_device(node); + } + } + + for (int i = 0; i < gf->n_nodes; i++) { + ggml_tensor * node = gf->nodes[i]; + if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { + ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + } + } + + // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed + if (!lctx.embedding.empty()) { + embeddings->backend = GGML_BACKEND_CPU; + } + res->backend = GGML_BACKEND_CPU; +#endif + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + // for big prompts, if BLAS is enabled, it is better to use only one thread + // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance + // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well + // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering + // with the BLAS calls. need a better solution + if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { + n_threads = std::min(4, n_threads); + } + + const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1; + if (ggml_cpu_has_cublas() && fully_offloaded) { + n_threads = 1; + } + +#if GGML_USE_MPI + const int64_t n_layer = hparams.n_layer; + ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); +#endif + +#ifdef GGML_USE_METAL + if (lctx.ctx_metal) { + ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); + ggml_metal_graph_compute(lctx.ctx_metal, gf); + } else { + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); + } +#else + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); +#endif + +#if GGML_USE_MPI + ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); +#endif + + // update the kv ring buffer + { + if (kv_self.has_shift) { + kv_self.has_shift = false; + for (uint32_t i = 0; i < kv_self.size; ++i) { + kv_self.cells[i].delta = 0; + } + } + + kv_self.head += n_tokens; + + // Ensure kv cache head points to a valid index. 
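+ // (the head only moves forward by n_tokens per decode, so wrapping it back to 0 once it
+ //  passes the end of the cache keeps it in range; the next llama_kv_cache_find_slot call
+ //  resumes its search for free cells from this position)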
+ if (kv_self.head >= kv_self.size) { + kv_self.head = 0; + } + } + +#ifdef GGML_PERF + // print timing information per ggml operation (for debugging purposes) + // requires GGML_PERF to be defined + ggml_graph_print(gf); +#endif + + // plot the computation graph in dot format (for debugging purposes) + //if (n_past%100 == 0) { + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); + //} + + // extract logits + // TODO: do not compute and extract logits if only embeddings are needed + // need to update the graphs to skip "result_output" + { + auto & logits_out = lctx.logits; + +#ifndef NDEBUG + auto & logits_valid = lctx.logits_valid; + logits_valid.clear(); + logits_valid.resize(n_tokens); + + logits_out.clear(); +#endif + + if (batch.logits) { + logits_out.resize(n_vocab * n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); +#ifndef NDEBUG + logits_valid[i] = true; +#endif + } + } else if (lctx.logits_all) { + logits_out.resize(n_vocab * n_tokens); + memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); +#ifndef NDEBUG + std::fill(logits_valid.begin(), logits_valid.end(), true); +#endif + } else { + logits_out.resize(n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); +#ifndef NDEBUG + logits_valid[0] = true; +#endif + } + } + + // extract embeddings + if (!lctx.embedding.empty()) { + auto & embedding_out = lctx.embedding; + + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(n_tokens - 1)), sizeof(float)*n_embd); + } + + // measure the performance only for the single-token evals + if (n_tokens == 1) { + lctx.t_eval_us += ggml_time_us() - t_start_us; + lctx.n_eval++; + } + else if (n_tokens > 1) { + lctx.t_p_eval_us += ggml_time_us() - t_start_us; + lctx.n_p_eval += n_tokens; + } + + // get a more accurate load time, upon first eval + // TODO: fix this + if (!lctx.has_evaluated_once) { + lctx.t_load_us = ggml_time_us() - lctx.t_start_us; + lctx.has_evaluated_once = true; + } + + return 0; +} + +// +// tokenizer +// + +static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { + return vocab.type; +} + +static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL; +} + +static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN; +} + +static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; +} + +static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; +} + +static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; +} + +static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { + GGML_ASSERT(llama_is_byte_token(vocab, id)); + const auto& token_data = vocab.id_to_token.at(id); + switch (llama_vocab_get_type(vocab)) { + case LLAMA_VOCAB_TYPE_SPM: { + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); + } + case LLAMA_VOCAB_TYPE_BPE: { + GGML_ASSERT(false); + return 
unicode_to_bytes_bpe(token_data.text); + } + default: + GGML_ASSERT(false); + } +} + +static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { + static const char * hex = "0123456789ABCDEF"; + switch (llama_vocab_get_type(vocab)) { + case LLAMA_VOCAB_TYPE_SPM: { + const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; + return vocab.token_to_id.at(buf); + } + case LLAMA_VOCAB_TYPE_BPE: { + return vocab.token_to_id.at(bytes_to_unicode_bpe(ch)); + } + default: + GGML_ASSERT(false); + } +} + +static void llama_escape_whitespace(std::string & text) { + replace_all(text, " ", "\xe2\x96\x81"); +} + +static void llama_unescape_whitespace(std::string & word) { + replace_all(word, "\xe2\x96\x81", " "); +} + +struct llm_symbol { + using index = int; + index prev; + index next; + const char * text; + size_t n; +}; + +static_assert(std::is_trivially_copyable::value, "llm_symbol is not trivially copyable"); + +// SPM tokenizer +// original implementation: +// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 + +struct llm_bigram_spm { + struct comparator { + bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) { + return (l.score < r.score) || (l.score == r.score && l.left > r.left); + } + }; + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + float score; + size_t size; +}; + +struct llm_tokenizer_spm { + llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + // split string into utf8 chars + int index = 0; + size_t offs = 0; + while (offs < text.size()) { + llm_symbol sym; + size_t len = utf8_len(text[offs]); + sym.text = text.c_str() + offs; + sym.n = std::min(len, text.size() - offs); + offs += sym.n; + sym.prev = index - 1; + sym.next = offs == text.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + + // seed the work queue with all possible 2-character tokens. + for (size_t i = 1; i < symbols.size(); ++i) { + try_add_bigram(i - 1, i); + } + + // keep substituting the highest frequency pairs for as long as we can. + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); + + auto & left_sym = symbols[bigram.left]; + auto & right_sym = symbols[bigram.right]; + + // if one of the symbols already got merged, skip it. + if (left_sym.n == 0 || right_sym.n == 0 || + left_sym.n + right_sym.n != bigram.size) { + continue; + } + + // merge the right sym into the left one + left_sym.n += right_sym.n; + right_sym.n = 0; + + //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size); + + // remove the right sym from the chain + left_sym.next = right_sym.next; + if (right_sym.next >= 0) { + symbols[right_sym.next].prev = bigram.left; + } + + // find more substitutions + try_add_bigram(left_sym.prev, bigram.left); + try_add_bigram(bigram.left, left_sym.next); + } + + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + resegment(symbol, output); + } + } + +private: + void resegment(llm_symbol & symbol, std::vector & output) { + auto text = std::string(symbol.text, symbol.n); + auto token = vocab.token_to_id.find(text); + + // Do we need to support is_unused? 
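+ // lookup order: try the whole merged symbol as a vocab token first, then fall back to the
+ // recorded bigram merge (rev_merge) so its two halves can be re-segmented recursively, and
+ // finally emit raw byte tokens as a last resort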
+ if (token != vocab.token_to_id.end()) { + output.push_back((*token).second); + return; + } + + const auto p = rev_merge.find(text); + + if (p == rev_merge.end()) { + // output any symbols that did not form tokens as bytes. + for (int j = 0; j < (int)symbol.n; ++j) { + llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]); + output.push_back(token_id); + } + return; + } + + resegment(symbols[p->second.first], output); + resegment(symbols[p->second.second], output); + } + + void try_add_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + + const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); + auto token = vocab.token_to_id.find(text); + + if (token == vocab.token_to_id.end()) { + return; + } + + if (static_cast((*token).second) >= vocab.id_to_token.size()) { + return; + } + + const auto & tok_data = vocab.id_to_token[(*token).second]; + + llm_bigram_spm bigram; + bigram.left = left; + bigram.right = right; + bigram.score = tok_data.score; + bigram.size = text.size(); + + work_queue.push(bigram); + + // Do we need to support is_unused? + rev_merge[text] = std::make_pair(left, right); + } + + const llama_vocab & vocab; + + std::vector symbols; + llm_bigram_spm::queue work_queue; + + std::map> rev_merge; +}; + +// BPE tokenizer +// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License] +// tried to simplify unicode stuff, so most likely does not work 100% correctly! + +// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused + +struct llm_bigram_bpe { + struct comparator { + bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { + return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); + } + }; + + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + std::string text; + int rank; + size_t size; +}; + +struct llm_tokenizer_bpe { + llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + int final_prev_index = -1; + auto word_collection = bpe_gpt2_preprocess(text); + + symbols_final.clear(); + + for (auto & word : word_collection) { + work_queue = llm_bigram_bpe::queue(); + symbols.clear(); + + int index = 0; + size_t offset = 0; + + while (offset < word.size()) { + llm_symbol sym; + size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset])); + sym.text = word.c_str() + offset; + sym.n = char_len; + offset += sym.n; + sym.prev = index - 1; + sym.next = offset == word.size() ? 
-1 : index + 1; + index++; + symbols.emplace_back(sym); + } + for (size_t i = 1; i < symbols.size(); ++i) { + add_new_bigram(i - 1, i); + } + + // build token(s) + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); + + auto & left_symbol = symbols[bigram.left]; + auto & right_symbol = symbols[bigram.right]; + + if (left_symbol.n == 0 || right_symbol.n == 0) { + continue; + } + std::string left_token = std::string(left_symbol.text, left_symbol.n); + std::string right_token = std::string(right_symbol.text, right_symbol.n); + if (left_token + right_token != bigram.text) { + continue; // Skip this bigram if it's outdated + } + + // merge the right sym into the left one + left_symbol.n += right_symbol.n; + right_symbol.n = 0; + + // remove the right sym from the chain + left_symbol.next = right_symbol.next; + if (right_symbol.next >= 0) { + symbols[right_symbol.next].prev = bigram.left; + } + + add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol + add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol + } + + // add the fnished tokens to the final list keeping correct order for next and prev + for (auto & sym : symbols) { + if (sym.n > 0) { + sym.prev = final_prev_index; + sym.next = -1; + if (final_prev_index != -1) { + symbols_final[final_prev_index].next = symbols_final.size(); + } + symbols_final.emplace_back(sym); + final_prev_index = symbols_final.size() - 1; + } + } + } + + symbols = symbols_final; + + if (!symbols.empty()) { + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + if (symbol.n == 0) { + continue; + } + + const std::string str = std::string(symbol.text, symbol.n); + const auto token = vocab.token_to_id.find(str); + + if (token == vocab.token_to_id.end()) { + for (auto j = str.begin(); j != str.end(); ++j) { + std::string byte_str(1, *j); + auto token_multibyte = vocab.token_to_id.find(byte_str); + if (token_multibyte == vocab.token_to_id.end()) { + throw std::runtime_error("ERROR: byte not found in vocab"); + } + output.push_back((*token_multibyte).second); + } + } else { + output.push_back((*token).second); + } + } + } + } + +private: + void add_new_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + + std::string left_token = std::string(symbols[left].text, symbols[left].n); + std::string right_token = std::string(symbols[right].text, symbols[right].n); + + int rank_found = -1; + + rank_found = vocab.find_bpe_rank(left_token, right_token); + + if (rank_found < 0) { + return; + } + + llm_bigram_bpe bigram; + + bigram.left = left; + bigram.right = right; + bigram.text = left_token + right_token; + bigram.size = left_token.size() + right_token.size(); + bigram.rank = rank_found; + + work_queue.push(bigram); + } + + std::vector bpe_gpt2_preprocess(const std::string & text) { + std::vector bpe_words; + std::vector bpe_encoded_words; + + std::string token = ""; + // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ + bool collecting_numeric = false; + bool collecting_letter = false; + bool collecting_special = false; + bool collecting_whitespace_lookahead = false; + bool collecting = false; + + std::vector text_utf; + text_utf.reserve(text.size()); + bpe_words.reserve(text.size()); + bpe_encoded_words.reserve(text.size()); + + auto cps = codepoints_from_utf8(text); + for (size_t i = 0; i < cps.size(); ++i) + text_utf.emplace_back(codepoint_to_utf8(cps[i])); + + for (int i = 0; i < 
(int)text_utf.size(); i++) { + const std::string & utf_char = text_utf[i]; + bool split_condition = false; + int bytes_remain = text_utf.size() - i; + // forward backward lookups + const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; + const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : ""; + + // handling contractions + if (!split_condition && bytes_remain >= 2) { + // 's|'t|'m|'d + if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { + split_condition = true; + } + if (split_condition) { + if (token.size()) { + bpe_words.emplace_back(token); // push previous content as token + } + token = utf_char + utf_char_next; + bpe_words.emplace_back(token); + token = ""; + i++; + continue; + } + } + if (!split_condition && bytes_remain >= 3) { + // 're|'ve|'ll + if (utf_char == "\'" && ( + (utf_char_next == "r" && utf_char_next_next == "e") || + (utf_char_next == "v" && utf_char_next_next == "e") || + (utf_char_next == "l" && utf_char_next_next == "l")) + ) { + split_condition = true; + } + if (split_condition) { + // current token + next token can be defined + if (token.size()) { + bpe_words.emplace_back(token); // push previous content as token + } + token = utf_char + utf_char_next + utf_char_next_next; + bpe_words.emplace_back(token); // the contraction + token = ""; + i += 2; + continue; + } + } + + if (!split_condition && !collecting) { + if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { + collecting_letter = true; + collecting = true; + } + else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { + collecting_numeric = true; + collecting = true; + } + else if ( + ((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || + (!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) + ) { + collecting_special = true; + collecting = true; + } + else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) { + collecting_whitespace_lookahead = true; + collecting = true; + } + else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) { + split_condition = true; + } + } + else if (!split_condition && collecting) { + if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) { + split_condition = true; + } + else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) { + split_condition = true; + } + else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { + split_condition = true; + } + else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { + split_condition = true; + } + } + + if (utf_char_next == "") { + split_condition = true; // final + token += utf_char; + } + + if (split_condition) { + if (token.size()) { + bpe_words.emplace_back(token); + } + 
token = utf_char; + collecting = false; + collecting_letter = false; + collecting_numeric = false; + collecting_special = false; + collecting_whitespace_lookahead = false; + } + else { + token += utf_char; + } + } + + for (std::string & word : bpe_words) { + std::string encoded_token = ""; + for (char & c : word) { + encoded_token += bytes_to_unicode_bpe(c); + } + bpe_encoded_words.emplace_back(encoded_token); + } + + return bpe_encoded_words; + } + + const llama_vocab & vocab; + + std::vector symbols; + std::vector symbols_final; + + llm_bigram_bpe::queue work_queue; +}; + +typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{ + FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, + FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT +} FRAGMENT_BUFFER_VARIANT_TYPE; + +struct fragment_buffer_variant{ + fragment_buffer_variant(llama_vocab::id _token) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), + token(_token), + raw_text(_dummy), + offset(0), + length(0){} + fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), + token((llama_vocab::id)-1), + raw_text(_raw_text), + offset(_offset), + length(_length){ + GGML_ASSERT( _offset >= 0 ); + GGML_ASSERT( _length >= 1 ); + GGML_ASSERT( offset + length <= raw_text.length() ); + } + + const FRAGMENT_BUFFER_VARIANT_TYPE type; + const llama_vocab::id token; + const std::string _dummy; + const std::string & raw_text; + const uint64_t offset; + const uint64_t length; +}; + +// #define PRETOKENIZERDEBUG + +static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) +{ + // for each special token + for (const auto & st: vocab.special_tokens_cache) { + const auto & special_token = st.first; + const auto & special_id = st.second; + + // for each text fragment + std::forward_list::iterator it = buffer.begin(); + while (it != buffer.end()) { + auto & fragment = (*it); + + // if a fragment is text ( not yet processed ) + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto * raw_text = &(fragment.raw_text); + + auto raw_text_base_offset = fragment.offset; + auto raw_text_base_length = fragment.length; + + // loop over the text + while (true) { + // find the first occurrence of a given special token in this fragment + // passing offset argument only limit the "search area" but match coordinates + // are still relative to the source full raw_text + auto match = raw_text->find(special_token, raw_text_base_offset); + + // no occurrences found, stop processing this fragment for a given special token + if (match == std::string::npos) break; + + // check if match is within bounds of offset <-> length + if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break; + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); +#endif + auto source = std::distance(buffer.begin(), it); + + // if match is further than base offset + // then we have some text to the left of it + if (match > raw_text_base_offset) { + // left + const int64_t left_reminder_offset = raw_text_base_offset + 0; + const int64_t left_reminder_length = match - raw_text_base_offset; + buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, 
left_reminder_length).c_str()); +#endif + it++; + } + + // special token + buffer.emplace_after(it, special_id); + it++; + + // right + if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { + const int64_t right_reminder_offset = match + special_token.length(); + const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); + buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); +#endif + + it++; + + if (source == 0) { + buffer.erase_after(buffer.before_begin()); + } else { + buffer.erase_after(std::next(buffer.begin(), (source-1))); + } + + // repeat for the right side + raw_text_base_offset = right_reminder_offset; + raw_text_base_length = right_reminder_length; + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); +#endif + } else { + if (source == 0) { + buffer.erase_after(buffer.before_begin()); + } else { + buffer.erase_after(std::next(buffer.begin(), (source-1))); + } + break; + } + } + } + it++; + } + } +} + +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) { + std::vector output; + + // OG tokenizer behavior: + // + // tokenizer.encode('', add_bos=True) returns [1] + // tokenizer.encode('', add_bos=False) returns [] + + if (bos && vocab.special_bos_id != -1) { + output.push_back(vocab.special_bos_id); + } + + if (raw_text.empty()) { + return output; + } + + std::forward_list fragment_buffer; + fragment_buffer.emplace_front( raw_text, 0, raw_text.length() ); + + if (special) tokenizer_st_partition( vocab, fragment_buffer ); + + switch (vocab.type) { + case LLAMA_VOCAB_TYPE_SPM: + { + for (const auto & fragment: fragment_buffer) + { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) + { + // without adding this leading whitespace, we do not get the same results as the original tokenizer + + // TODO: It's likely possible to get rid of this string copy entirely + // by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer + // and passing 'add space prefix' as bool argument + // + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + if (&fragment == &fragment_buffer.front()) { + raw_text = " " + raw_text; // prefix with space if the first token is not special + } + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + llm_tokenizer_spm tokenizer(vocab); + llama_escape_whitespace(raw_text); + tokenizer.tokenize(raw_text, output); + } + else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + { + output.push_back(fragment.token); + } + } + } break; + case LLAMA_VOCAB_TYPE_BPE: + { + for (const auto & fragment: fragment_buffer) + { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) + { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + llm_tokenizer_bpe tokenizer(vocab); + tokenizer.tokenize(raw_text, output); + } + else // if 
(fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + { + output.push_back(fragment.token); + } + } + } break; + } + + return output; +} + +// +// grammar - internal +// + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar { + const std::vector> rules; + std::vector> stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens + llama_partial_utf8 partial_utf8; +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. +static std::pair, llama_partial_utf8> decode_utf8( + const std::string & src, + llama_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; + const char * pos = src.c_str(); + std::vector code_points; + // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. + code_points.reserve(src.size() + 1); + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); + } + } + code_points.push_back(0); + + return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { + switch (pos->type) { + case LLAMA_GRETYPE_END: return true; // NOLINT + case LLAMA_GRETYPE_ALT: return true; // NOLINT + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair llama_grammar_match_char( + const llama_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. 
[a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static bool llama_grammar_match_partial_char( + const llama_grammar_element * pos, + const llama_partial_utf8 partial_utf8) { + + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + uint32_t partial_value = partial_utf8.value; + int n_remain = partial_utf8.n_remain; + + // invalid sequence or 7-bit char split across 2 bytes (overlong) + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { + return false; + } + + // range of possible code points this partial UTF-8 sequence could complete to + uint32_t low = partial_value << (n_remain * 6); + uint32_t high = low | ((1 << (n_remain * 6)) - 1); + + if (low == 0) { + if (n_remain == 2) { + low = 1 << 11; + } else if (n_remain == 3) { + low = 1 << 16; + } + } + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= high && low <= pos[1].value) { + return is_positive_char; + } + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + if (low <= pos->value && pos->value <= high) { + return is_positive_char; + } + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return !is_positive_char; +} + + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void llama_grammar_advance_stack( + const std::vector> & rules, + const std::vector & stack, + std::vector> & new_stacks) { + + if (stack.empty()) { + new_stacks.emplace_back(stack); + return; + } + + const llama_grammar_element * pos = stack.back(); + + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + new_stacks.emplace_back(stack); + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + GGML_ASSERT(false); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector> llama_grammar_accept( + const std::vector> & rules, + const std::vector> & 
stacks, + const uint32_t chr) { + + std::vector> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; + + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + } + } + + return new_stacks; +} + +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates); + +static std::vector llama_grammar_reject_candidates_for_stack( + const std::vector> & rules, + const std::vector & stack, + const std::vector & candidates) { + + std::vector rejects; + + if (stack.empty()) { + for (const auto & tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + std::vector next_candidates; + for (const auto & tok : candidates) { + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && + !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); + } + } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + } else { + rejects.push_back(tok); + } + } + + const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + std::vector stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + std::vector> next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (const auto & tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + } + + return rejects; +} + +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates) { + GGML_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + return std::vector(); + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} + +// +// grammar - external +// + +struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; + + // copy rule definitions into vectors + std::vector> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // loop over alternates of start rule to build initial stacks + std::vector> stacks; + pos = rules[start_rule_index]; + do { + std::vector stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if 
alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; +} + +void llama_grammar_free(struct llama_grammar * grammar) { + delete grammar; +} + +struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { + llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; + + // redirect elements in stacks to point to new rules + for (size_t is = 0; is < result->stacks.size(); is++) { + for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { + for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { + if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { + result->stacks[is][ie] = &result->rules[ir0][ir1]; + } + } + } + } + } + + return result; +} + +// +// sampling +// + +void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + +void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { + GGML_ASSERT(candidates->size > 0); + + const int64_t t_start_sample_us = ggml_time_us(); + + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + float max_l = candidates->data[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + float p = expf(candidates->data[i].logit - max_l); + candidates->data[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].p /= cum_sum; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep) { + const int64_t t_start_sample_us = ggml_time_us(); + + k = std::max(k, (int) min_keep); + k = std::min(k, (int) candidates->size); + + // Sort scores in descending order + if (!candidates->sorted) { + auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + if (k == (int) candidates->size) { + std::sort(candidates->data, candidates->data + candidates->size, comp); + } else { + std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); + } + candidates->sorted = true; + } + candidates->size = k; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p >= 1.0f) { + return; + } + + llama_sample_softmax(ctx, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + + for (size_t i = 0; i < candidates->size; ++i) { + cum_sum += candidates->data[i].p; + + // Check if the running sum is at least p or if we have kept at least min_keep tokens + // we set the 
last index to i+1 to indicate that the current iterate should be included in the set + if (cum_sum >= p && i + 1 >= min_keep) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the top-p tokens + candidates->size = last_idx; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p <= 0.0f || !candidates->size) { + return; + } + + llama_sample_softmax(ctx, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + float scale = candidates->data[0].p; // scale by max prob + size_t i = 1; // first token always matches + + for (; i < candidates->size; ++i) { + if (candidates->data[i].p < p * scale && i >= min_keep) { + break; // prob too small + } + } + + // Resize the output vector to keep only the matching tokens + candidates->size = i; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { + if (z >= 1.0f || candidates->size <= 2) { + return; + } + + llama_sample_softmax(nullptr, candidates); + const int64_t t_start_sample_us = ggml_time_us(); + + // Compute the first and second derivatives + std::vector first_derivatives(candidates->size - 1); + std::vector second_derivatives(candidates->size - 2); + + for (size_t i = 0; i < first_derivatives.size(); ++i) { + first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; + } + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; + } + + // Calculate absolute value of second derivatives + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = std::abs(second_derivatives[i]); + } + + // Normalize the second derivatives + { + const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); + + if (second_derivatives_sum > 1e-6f) { + for (float & value : second_derivatives) { + value /= second_derivatives_sum; + } + } else { + for (float & value : second_derivatives) { + value = 1.0f / second_derivatives.size(); + } + } + } + + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + for (size_t i = 0; i < second_derivatives.size(); ++i) { + cum_sum += second_derivatives[i]; + + // Check if the running sum is greater than z or if we have kept at least min_keep tokens + if (cum_sum > z && i >= min_keep) { + last_idx = i; + break; + } + } + + // Resize the output vector to keep only the tokens above the tail location + candidates->size = last_idx; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + // Reference implementation: + // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr + if (p >= 1.0f) { + return; + } + + // Compute the softmax of logits and calculate entropy + llama_sample_softmax(nullptr, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + float entropy = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + entropy += -candidates->data[i].p * logf(candidates->data[i].p); + } + + // Compute the absolute difference between negative log probability and entropy for each candidate + std::vector shifted_scores; 
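+    // In other words, shifted_score = |(-log p_i) - H|: the distance between a token's surprisal
+    // and the entropy H computed above. Locally typical sampling keeps the tokens whose surprisal
+    // is closest to H, so a smaller shifted_score means "more typical".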
+ for (size_t i = 0; i < candidates->size; ++i) { + float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); + shifted_scores.push_back(shifted_score); + } + + // Sort tokens based on the shifted_scores and their corresponding indices + std::vector indices(candidates->size); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = indices.size(); + + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + cum_sum += candidates->data[idx].p; + + // Check if the running sum is greater than typical or if we have kept at least min_keep tokens + if (cum_sum > p && i >= min_keep - 1) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the locally typical tokens + std::vector new_candidates; + for (size_t i = 0; i < last_idx; ++i) { + size_t idx = indices[i]; + new_candidates.push_back(candidates->data[idx]); + } + + // Replace the data in candidates with the new_candidates data + std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); + candidates->size = new_candidates.size(); + candidates->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + const int64_t t_start_sample_us = ggml_time_us(); + + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].logit /= temp; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + llama_sample_temp(ctx, candidates_p, temp); +} + +void llama_sample_repetition_penalties( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present) { + if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // Create a frequency map to count occurrences of each token in last_tokens + std::unordered_map token_count; + for (size_t i = 0; i < penalty_last_n; ++i) { + token_count[last_tokens[i]]++; + } + + // Apply frequency and presence penalties to the candidates + for (size_t i = 0; i < candidates->size; ++i) { + const auto token_iter = token_count.find(candidates->data[i].id); + if (token_iter == token_count.end()) { + continue; + } + + const int count = token_iter->second; + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. 
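+        // e.g. with penalty_repeat = 1.3: a logit of +2.0 becomes 2.0 / 1.3 ~= 1.54, while a logit
+        // of -2.0 becomes -2.0 * 1.3 = -2.6. Dividing the negative logit instead (-2.0 / 1.3 ~= -1.54)
+        // would make the repeated token *more* likely, which is the failure mode described above.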
+ if (candidates->data[i].logit <= 0) { + candidates->data[i].logit *= penalty_repeat; + } else { + candidates->data[i].logit /= penalty_repeat; + } + + candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; + } + + candidates->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { + GGML_ASSERT(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + + bool allow_eos = false; + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + allow_eos = true; + break; + } + } + + const llama_token eos = llama_token_eos(&ctx->model); + + std::vector, llama_partial_utf8>> candidates_decoded; + candidates_decoded.reserve(candidates->size); + std::vector candidates_grammar; + candidates_grammar.reserve(candidates->size); + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const std::string piece = llama_token_to_piece(ctx, id); + if (id == eos) { + if (!allow_eos) { + candidates->data[i].logit = -INFINITY; + } + } else if (piece.empty() || piece[0] == 0) { + candidates->data[i].logit = -INFINITY; + } else { + candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + } + } + + const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + for (const auto & reject : rejects) { + candidates->data[reject.index].logit = -INFINITY; + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + +static void llama_log_softmax(float * array, size_t size) { + float max_l = *std::max_element(array, array + size); + float sum = 0.f; + for (size_t i = 0; i < size; ++i) { + float p = expf(array[i] - max_l); + sum += p; + array[i] = p; + } + + for (size_t i = 0; i < size; ++i) { + array[i] = logf(array[i] / sum); + } +} + +void llama_sample_classifier_free_guidance( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_context * guidance_ctx, + float scale) { + int64_t t_start_sample_us = ggml_time_us(); + + GGML_ASSERT(ctx); + + auto n_vocab = llama_n_vocab(llama_get_model(ctx)); + + GGML_ASSERT(n_vocab == (int)candidates->size); + GGML_ASSERT(!candidates->sorted); + + std::vector logits_base; + logits_base.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + logits_base.push_back(candidates->data[i].logit); + } + llama_log_softmax(logits_base.data(), candidates->size); + + float* logits_guidance = llama_get_logits(guidance_ctx); + llama_log_softmax(logits_guidance, n_vocab); + + for (int i = 0; i < n_vocab; ++i) { + float logit_guidance = logits_guidance[i]; + float logit_base = logits_base[i]; + candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { + GGML_ASSERT(ctx); + + auto N = float(llama_n_vocab(llama_get_model(ctx))); + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(nullptr, candidates); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float 
sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); + + // Sample the next word X using top-k sampling + llama_sample_top_k(nullptr, candidates, int(k), 1); + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + llama_token X = llama_sample_token(ctx, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + return X; +} + +llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(ctx, candidates); + + // Truncate the words with surprise values greater than mu + candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > *mu; + })); + + if (candidates->size == 0) { + candidates->size = 1; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // Normalize the probabilities of the remaining words + llama_sample_softmax(ctx, candidates); + + // Sample the next word X from the remaining words + llama_token X = llama_sample_token(ctx, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + return X; +} + +llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { + const int64_t t_start_sample_us = ggml_time_us(); + + // Find max element + auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); + + llama_token result = max_iter->id; + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + } + return result; +} + +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { + GGML_ASSERT(ctx); + + const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax(nullptr, candidates); + + std::vector probs; + 
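+    // The probabilities copied below are the softmax values computed above; std::discrete_distribution
+    // then draws index i with probability probs[i] / sum(probs) using the per-context RNG, and the
+    // token at that index is returned.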
probs.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + probs.push_back(candidates->data[i].p); + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + auto & rng = ctx->rng; + int idx = dist(rng); + + llama_token result = candidates->data[idx].id; + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + return result; +} + +void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { + const int64_t t_start_sample_us = ggml_time_us(); + + if (token == llama_token_eos(&ctx->model)) { + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + return; + } + } + GGML_ASSERT(false); + } + + const std::string piece = llama_token_to_piece(ctx, token); + + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto & code_points = decoded.first; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + } + grammar->partial_utf8 = decoded.second; + GGML_ASSERT(!grammar->stacks.empty()); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + +// +// Beam search +// + +struct llama_beam { + std::vector tokens; + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Initialize end-of-beam to false. Callback sets this to true. + // Sort beams by probability. In case of ties, prefer beams at eob. + bool operator<(const llama_beam & rhs) const { + return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob); + } + // Shift off first n tokens and discard them. + void shift_tokens(const size_t n) { + if (n) { + std::copy(tokens.begin() + n, tokens.end(), tokens.begin()); + tokens.resize(tokens.size() - n); + } + } + llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; } +}; + +// A struct for calculating logit-related info. +struct llama_logit_info { + const float * const logits; + const int n_vocab; + const float max_l; + const float normalizer; + struct sum_exp { + float max_l; + float operator()(float sum, float l) const { return sum + std::exp(l - max_l); } + }; + llama_logit_info(llama_context * ctx) + : logits(llama_get_logits(ctx)) + , n_vocab(llama_n_vocab(llama_get_model(ctx))) + , max_l(*std::max_element(logits, logits + n_vocab)) + , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l})) + { } + llama_token_data get_token_data(const llama_token token_id) const { + constexpr auto p = std::numeric_limits::quiet_NaN(); // never used + return {token_id, logits[token_id], p}; + } + // Return top k token_data by logit. 
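+    // Selection strategy: keep a min-heap of size k ordered by logit. The first k token ids seed
+    // the heap; every later logit that beats the current minimum (the heap front) replaces it.
+    // This yields the top-k logits in O(n_vocab * log k) without sorting the whole vocabulary.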
+ std::vector top_k(size_t k) { + std::vector min_heap; // min-heap by logit + const llama_token k_min = std::min(static_cast(k), n_vocab); + min_heap.reserve(k_min); + for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) { + min_heap.push_back(get_token_data(token_id)); + } + auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; + std::make_heap(min_heap.begin(), min_heap.end(), comp); + for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) { + if (min_heap.front().logit < logits[token_id]) { + std::pop_heap(min_heap.begin(), min_heap.end(), comp); + min_heap.back().id = token_id; + min_heap.back().logit = logits[token_id]; + std::push_heap(min_heap.begin(), min_heap.end(), comp); + } + } + return min_heap; + } + float probability_from_logit(float logit) const { + return normalizer * std::exp(logit - max_l); + } +}; + +struct llama_beam_search_data { + llama_context * ctx; + size_t n_beams; + int n_past; + int n_predict; + std::vector beams; + std::vector next_beams; + + // Re-calculated on each loop iteration + size_t common_prefix_length; + + // Used to communicate to/from callback on beams state. + std::vector beam_views; + + llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict) + : ctx(ctx) + , n_beams(n_beams) + , n_past(n_past) + , n_predict(n_predict) + , beam_views(n_beams) { + beams.reserve(n_beams); + next_beams.reserve(n_beams); + } + + // Collapse beams to a single beam given by index. + void collapse_beams(const size_t beam_idx) { + if (0u < beam_idx) { + std::swap(beams[0], beams[beam_idx]); + } + beams.resize(1); + } + + // Min-heaps are used to efficiently collect the top-k elements (k=n_beams). + // The repetitive patterns below reflect the 2 stages of heaps: + // * Gather elements until the vector is full, then call std::make_heap() on it. + // * If the heap is full and a new element is found that should be included, pop the + // least element to the back(), replace it with the new, then push it into the heap. + void fill_next_beams_by_top_probabilities(llama_beam & beam) { + // Min-heaps use a greater-than comparator. + const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; }; + if (beam.eob) { + // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough. + if (next_beams.size() < n_beams) { + next_beams.push_back(std::move(beam)); + if (next_beams.size() == n_beams) { + std::make_heap(next_beams.begin(), next_beams.end(), comp); + } + } else if (next_beams.front().p < beam.p) { + std::pop_heap(next_beams.begin(), next_beams.end(), comp); + next_beams.back() = std::move(beam); + std::push_heap(next_beams.begin(), next_beams.end(), comp); + } + } else { + // beam is not at end-of-sentence, so branch with next top_k tokens. 
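+            // Branching works as follows: decode this beam's pending tokens, take the top n_beams
+            // candidate tokens, and form child beams by appending one candidate and multiplying
+            // beam.p by that token's probability; the min-heap in next_beams then keeps only the
+            // n_beams most probable children across all parent beams.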
+ if (!beam.tokens.empty()) { + llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0)); + } + llama_logit_info logit_info(ctx); + std::vector next_tokens = logit_info.top_k(n_beams); + size_t i=0; + if (next_beams.size() < n_beams) { + for (; next_beams.size() < n_beams ; ++i) { + llama_beam next_beam = beam; + next_beam.tokens.push_back(next_tokens[i].id); + next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit); + next_beams.push_back(std::move(next_beam)); + } + std::make_heap(next_beams.begin(), next_beams.end(), comp); + } else { + for (; next_beams.front().p == 0.0f ; ++i) { + std::pop_heap(next_beams.begin(), next_beams.end(), comp); + next_beams.back() = beam; + next_beams.back().tokens.push_back(next_tokens[i].id); + next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit); + std::push_heap(next_beams.begin(), next_beams.end(), comp); + } + } + for (; i < n_beams ; ++i) { + const float next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit); + if (next_beams.front().p < next_p) { + std::pop_heap(next_beams.begin(), next_beams.end(), comp); + next_beams.back() = beam; + next_beams.back().tokens.push_back(next_tokens[i].id); + next_beams.back().p = next_p; + std::push_heap(next_beams.begin(), next_beams.end(), comp); + } + } + } + } + + // Find common_prefix_length based on beams. + // Requires beams is not empty. + size_t find_common_prefix_length() { + size_t common_prefix_length = beams[0].tokens.size(); + for (size_t i = 1 ; i < beams.size() ; ++i) { + common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size()); + for (size_t j = 0 ; j < common_prefix_length ; ++j) { + if (beams[0].tokens[j] != beams[i].tokens[j]) { + common_prefix_length = j; + break; + } + } + } + return common_prefix_length; + } + + // Construct beams_state to send back to caller via the callback function. + // Side effect: set common_prefix_length = find_common_prefix_length(); + llama_beams_state get_beams_state(const bool last_call) { + for (size_t i = 0 ; i < beams.size() ; ++i) { + beam_views[i] = beams[i].view(); + } + common_prefix_length = find_common_prefix_length(); + return {beam_views.data(), beams.size(), common_prefix_length, last_call}; + } + + // Loop: + // * while i < n_predict, AND + // * any of the beams have not yet reached end-of-beam (eob), AND + // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence + // (since all other beam probabilities can only decrease) + void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) { + beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eob. + const auto not_eob = [](const llama_beam & beam) { return !beam.eob; }; + for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) && + !beams[top_beam_index()].eob ; ++i) { + callback(callback_data, get_beams_state(false)); // Sets common_prefix_length + update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. + if (common_prefix_length) { + llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0)); + n_past += common_prefix_length; + } + // Zero-out next_beam probabilities to place them last in following min-heap. 
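+            // A beam with p == 0.0f compares as the least probable under the greater-than
+            // comparator, so the stale entries carried over from the previous iteration's swap
+            // are the first ones to be replaced when next_beams is refilled.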
+ std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam & beam) { beam.p = 0.0f; }); + for (llama_beam & beam : beams) { + beam.shift_tokens(common_prefix_length); + fill_next_beams_by_top_probabilities(beam); + } + // next_beams become the beams of next/final iteration. Swap them to re-use memory. + beams.swap(next_beams); + renormalize_beam_probabilities(beams); + } + collapse_beams(top_beam_index()); + callback(callback_data, get_beams_state(true)); + } + + // As beams grow, the cumulative probabilities decrease. + // Renormalize them to avoid floating point underflow. + static void renormalize_beam_probabilities(std::vector & beams) { + const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; }; + const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p); + std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; }); + } + + // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering. + size_t top_beam_index() { + return std::max_element(beams.begin(), beams.end()) - beams.begin(); + } + + // Copy (p,eob) for each beam which may have been changed by the callback. + void update_beams_from_beam_views() { + for (size_t i = 0 ; i < beams.size() ; ++i) { + beams[i].p = beam_views[i].p; + beams[i].eob = beam_views[i].eob; + } + } +}; + +void llama_beam_search(llama_context * ctx, + llama_beam_search_callback_fn_t callback, void * callback_data, + size_t n_beams, int n_past, int n_predict) { + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + + llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict); + + beam_search_data.loop(callback, callback_data); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; +} + +// +// quantization +// + +template +struct no_init { + T value; + no_init() { /* do nothing */ } +}; + +struct quantize_state_internal { + const llama_model & model; + const llama_model_quantize_params * params; + + int n_attention_wv = 0; + int n_feed_forward_w2 = 0; + int i_attention_wv = 0; + int i_feed_forward_w2 = 0; + + int n_k_quantized = 0; + int n_fallback = 0; + + quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params) + : model(model) + , params(params) + {} +}; + +static void llama_convert_tensor_internal( + struct ggml_tensor * tensor, std::vector> & output, std::vector & workers, + const size_t nelements, const int nthread +) { + if (output.size() < nelements) { + output.resize(nelements); + } + float * f32_output = (float *) output.data(); + + ggml_type_traits_t qtype; + if (ggml_is_quantized(tensor->type)) { + qtype = ggml_internal_get_type_traits(tensor->type); + if (qtype.to_float == NULL) { + throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type))); + } + } else if (tensor->type != GGML_TYPE_F16) { + throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); + } + + if (nthread < 2) { + if (tensor->type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); + } else if (ggml_is_quantized(tensor->type)) { + qtype.to_float(tensor->data, f32_output, nelements); + } else { + GGML_ASSERT(false); // unreachable + } + return; + } + + size_t block_size = tensor->type == GGML_TYPE_F16 ? 
1 : (size_t)ggml_blck_size(tensor->type); + size_t block_size_bytes = ggml_type_size(tensor->type); + + GGML_ASSERT(nelements % block_size == 0); + size_t nblocks = nelements / block_size; + size_t blocks_per_thread = nblocks / nthread; + size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count + + size_t in_buff_offs = 0; + size_t out_buff_offs = 0; + + for (int tnum = 0; tnum < nthread; tnum++) { + size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread + size_t thr_elems = thr_blocks * block_size; // number of elements for this thread + size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread + + auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { + if (typ == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); + } else { + qtype.to_float(inbuf, outbuf, nels); + } + }; + workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems); + in_buff_offs += thr_block_bytes; + out_buff_offs += thr_elems; + } + for (auto & w : workers) { w.join(); } + workers.clear(); +} + +static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { + const std::string name = ggml_get_name(tensor); + + // TODO: avoid hardcoded tensor names - use the TN_* constants + const llm_arch arch = qs.model.arch; + const auto tn = LLM_TN(arch); + + auto use_more_bits = [](int i_layer, int num_layers) -> bool { + return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2; + }; + + if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { + int nx = tensor->ne[0]; + if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) { + new_type = GGML_TYPE_Q8_0; + } + else if (new_type != GGML_TYPE_Q8_0) { + new_type = GGML_TYPE_Q6_K; + } + } else if (name.find("attn_v.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && + use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; + else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && + (qs.i_attention_wv < qs.n_attention_wv/8 || qs.i_attention_wv >= 7*qs.n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; + if (qs.model.type == MODEL_70B) { + // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is + // 8x smaller compared to attn_q.weight. 
Hence, we can get a nice boost in quantization accuracy with + // nearly negligible increase in model size by quantizing this tensor with more bits: + if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K; + } + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } + ++qs.i_attention_wv; + } else if (name.find("attn_k.weight") != std::string::npos) { + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } + } else if (name.find("ffn_down.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = qs.i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K + : arch != LLM_ARCH_FALCON || use_more_bits(qs.i_feed_forward_w2, qs.n_feed_forward_w2) ? GGML_TYPE_Q4_K + : GGML_TYPE_Q3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) { + new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { + if (arch == LLM_ARCH_FALCON) { + new_type = qs.i_feed_forward_w2 < 2 ? GGML_TYPE_Q6_K : + use_more_bits(qs.i_feed_forward_w2, qs.n_feed_forward_w2) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } else { + if (use_more_bits(qs.i_feed_forward_w2, qs.n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; + } + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(qs.i_feed_forward_w2, qs.n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && qs.i_feed_forward_w2 < 4) { + new_type = GGML_TYPE_Q5_K; + } + ++qs.i_feed_forward_w2; + } else if (name.find("attn_output.weight") != std::string::npos) { + if (arch != LLM_ARCH_FALCON) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; + } + } + else if (name.find("attn_qkv.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; + } + else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + } + // This can be used to reduce the size of the Q5_K_S model. 
+ // The associated PPL increase is fully in line with the size reduction + //else { + // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; + //} + bool convert_incompatible_tensor = false; + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || + new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) { + int nx = tensor->ne[0]; + int ny = tensor->ne[1]; + if (nx % QK_K != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } else { + ++qs.n_k_quantized; + } + } + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_Q2_K: new_type = GGML_TYPE_Q4_0; break; + case GGML_TYPE_Q3_K: new_type = GGML_TYPE_Q4_1; break; + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + ++qs.n_fallback; + } + + return new_type; +} + +static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { + ggml_type quantized_type; + llama_ftype ftype = params->ftype; + + switch (params->ftype) { + case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break; + case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; + case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; + case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; + + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: + case LLAMA_FTYPE_MOSTLY_Q3_K_M: + case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: + case LLAMA_FTYPE_MOSTLY_Q4_K_M: quantized_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: + case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; + case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; + + default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); + } + + int nthread = params->nthread; + + if (nthread <= 0) { + nthread = std::thread::hardware_concurrency(); + } + + // mmap consistently increases speed Linux, and also increases speed on Windows with + // hot cache. It may cause a slowdown on macOS, possibly related to free memory. 
+#if defined(__linux__) || defined(_WIN32) + constexpr bool use_mmap = true; +#else + constexpr bool use_mmap = false; +#endif + + llama_model_loader ml(fname_inp, use_mmap, NULL); + if (ml.use_mmap) { + ml.mapping.reset(new llama_mmap(&ml.file, /* prefetch */ 0, ggml_is_numa())); + } + + llama_model model; + llm_load_arch(ml, model); + llm_load_hparams(ml, model); + + struct quantize_state_internal qs(model, params); + + if (params->only_copy) { + ftype = model.ftype; + } + + const size_t align = GGUF_DEFAULT_ALIGNMENT; + struct gguf_context * ctx_out = gguf_init_empty(); + + // copy the KV pairs from the input file + gguf_set_kv (ctx_out, ml.ctx_gguf); + gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); + gguf_set_val_u32(ctx_out, "general.file_type", ftype); + + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * meta = ml.get_tensor_meta(i); + + const std::string name = ggml_get_name(meta); + + // TODO: avoid hardcoded tensor names - use the TN_* constants + if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { + ++qs.n_attention_wv; + } + else if (name.find("ffn_down.weight") != std::string::npos) { + ++qs.n_feed_forward_w2; + } + } + if (qs.n_attention_wv != qs.n_feed_forward_w2 || (uint32_t)qs.n_attention_wv != model.hparams.n_layer) { + LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_feed_forward_w2 = %d, hparams.n_layer = %d\n", + __func__, qs.n_attention_wv, qs.n_feed_forward_w2, model.hparams.n_layer); + } + + size_t total_size_org = 0; + size_t total_size_new = 0; + std::vector hist_all(1 << 4, 0); + + std::vector workers; + workers.reserve(nthread); + std::mutex mutex; + + int idx = 0; + + std::vector> read_data; + std::vector> work; + std::vector> f32_conv_buf; + + // populate the original tensors so we get an initial meta data + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * meta = ml.get_tensor_meta(i); + gguf_add_tensor(ctx_out, meta); + } + + std::ofstream fout(fname_out, std::ios::binary); + fout.exceptions(std::ofstream::failbit); // fail fast on write errors + + const size_t meta_size = gguf_get_meta_size(ctx_out); + + LLAMA_LOG_INFO("%s: meta size = %zu bytes\n", __func__, meta_size); + + // placeholder for the meta data + ::zeros(fout, meta_size); + + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * tensor = ml.get_tensor_meta(i); + + const std::string name = ggml_get_name(tensor); + + if (!ml.use_mmap) { + if (read_data.size() < ggml_nbytes(tensor)) { + read_data.resize(ggml_nbytes(tensor)); + } + tensor->data = read_data.data(); + } + ml.load_data_for(tensor); + + LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", + ++idx, ml.n_tensors, + ggml_get_name(tensor), + llama_format_tensor_shape(tensor).c_str(), + ggml_type_name(tensor->type)); + + // This used to be a regex, but has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? 
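+        // e.g. "blk.0.attn_q.weight" satisfies this check (a trailing "weight" starts 6 chars from
+        // the end), while bias tensors do not; the conditions below narrow the set further.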
+ + // quantize only 2D tensors + quantize &= (ggml_n_dims(tensor) == 2); + quantize &= params->quantize_output_tensor || name != "output.weight"; + quantize &= !params->only_copy; + + // do not quantize expert gating tensors + quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; + + enum ggml_type new_type; + void * new_data; + size_t new_size; + + if (quantize) { + new_type = quantized_type; + if (!params->pure) { + new_type = get_k_quant_type(qs, new_type, tensor, ftype); + } + + // If we've decided to quantize to the same type the tensor is already + // in then there's nothing to do. + quantize = tensor->type != new_type; + } + if (!quantize) { + new_type = tensor->type; + new_data = tensor->data; + new_size = ggml_nbytes(tensor); + LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); + } else { + const size_t nelements = ggml_nelements(tensor); + + float * f32_data; + + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); + } else { + llama_convert_tensor_internal(tensor, f32_conv_buf, workers, nelements, nthread); + f32_data = (float *) f32_conv_buf.data(); + } + + LLAMA_LOG_INFO("quantizing to %s .. ", ggml_type_name(new_type)); + fflush(stdout); + + if (work.size() < nelements * 4) { + work.resize(nelements * 4); // upper bound on size + } + new_data = work.data(); + std::array hist_cur = {}; + + static const int chunk_size = 32 * 512; + const int nchunk = (nelements + chunk_size - 1)/chunk_size; + const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; + if (nthread_use < 2) { + new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data()); + } else { + size_t counter = 0; + new_size = 0; + auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements]() { + std::array local_hist = {}; + size_t local_size = 0; + while (true) { + std::unique_lock lock(mutex); + size_t first = counter; counter += chunk_size; + if (first >= nelements) { + if (local_size > 0) { + for (int j=0; j %8.2f MiB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); + int64_t tot_count = 0; + for (size_t i = 0; i < hist_cur.size(); i++) { + hist_all[i] += hist_cur[i]; + tot_count += hist_cur[i]; + } + + if (tot_count > 0) { + for (size_t i = 0; i < hist_cur.size(); i++) { + LLAMA_LOG_INFO("%5.3f ", hist_cur[i] / float(nelements)); + } + } + LLAMA_LOG_INFO("\n"); + } + total_size_org += ggml_nbytes(tensor); + total_size_new += new_size; + + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_out, name.c_str(), new_type); + gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); + + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } + + // go back to beginning of file and write the updated meta data + { + fout.seekp(0); + std::vector data(gguf_get_meta_size(ctx_out)); + gguf_get_meta_data(ctx_out, data.data()); + fout.write((const char *) data.data(), data.size()); + } + + fout.close(); + + gguf_free(ctx_out); + + LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + + // print histogram for all tensors + { + int64_t sum_all = 0; + for (size_t i 
= 0; i < hist_all.size(); i++) { + sum_all += hist_all[i]; + } + + if (sum_all > 0) { + LLAMA_LOG_INFO("%s: hist: ", __func__); + for (size_t i = 0; i < hist_all.size(); i++) { + LLAMA_LOG_INFO("%5.3f ", hist_all[i] / float(sum_all)); + } + LLAMA_LOG_INFO("\n"); + } + } + + if (qs.n_fallback > 0) { + LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) incompatible with k-quants and required fallback quantization\n", + __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback); + } +} + +static int llama_apply_lora_from_file_internal( + const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads +) { + LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); + + const int64_t t_start_lora_us = ggml_time_us(); + + llama_file fin(path_lora, "rb"); + + // verify magic and version + { + uint32_t magic = fin.read_u32(); + if (magic != LLAMA_FILE_MAGIC_GGLA) { + LLAMA_LOG_ERROR("%s: bad file magic\n", __func__); + return 1; + } + + uint32_t format_version = fin.read_u32(); + if (format_version != 1) { + LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ ); + return 1; + } + } + + int32_t lora_r = fin.read_u32(); + int32_t lora_alpha = fin.read_u32(); + float scaling = scale * (float)lora_alpha / (float)lora_r; + + LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); + + // create a name -> tensor map of the model to accelerate lookups + // find the max tensor size to estimate the required temporary buffer size + size_t max_tensor_size = 0; + std::unordered_map model_tensors; + for (const auto & kv : model.tensors_by_name) { + model_tensors.insert(kv); + size_t f32_size = ggml_nelements(kv.second) * sizeof(float); + max_tensor_size = std::max(max_tensor_size, f32_size); + } + + // create a temporary ggml context to store the lora tensors + // TODO: use ggml-alloc + size_t lora_ctx_size = max_tensor_size * 3; + LLAMA_LOG_INFO("%s: allocating %.f MB for lora temporary buffer\n", __func__, lora_ctx_size / 1024.0 / 1024.0); + std::vector lora_buf(lora_ctx_size); + + struct ggml_init_params params; + params.mem_size = lora_buf.size(); + params.mem_buffer = lora_buf.data(); + params.no_alloc = false; + + using unique_context = std::unique_ptr; + + unique_context lora_ctx(nullptr, ggml_free); + lora_ctx.reset(ggml_init(params)); + std::unordered_map lora_tensors; + + // load base model + std::unique_ptr ml; + + unique_context base_ctx(nullptr, ggml_free); + std::vector base_buf; + if (path_base_model) { + LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); + ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ NULL)); + + size_t ctx_size; + size_t mmapped_size; + ml->calc_sizes(ctx_size, mmapped_size); + + base_buf.resize(ctx_size); + + ggml_init_params base_params; + base_params.mem_size = base_buf.size(); + base_params.mem_buffer = base_buf.data(); + base_params.no_alloc = ml->use_mmap; + + base_ctx.reset(ggml_init(base_params)); + + // maybe this should be in llama_model_loader + if (ml->use_mmap) { + ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa())); + } + } + + // read tensors and apply + bool warned = false; + int n_tensors = 0; + + std::vector work_buffer; + + while (true) { + if (fin.tell() == fin.size) { + // eof + break; + } + + int32_t n_dims; + int32_t name_len; + int32_t ftype; + + fin.read_raw(&n_dims, sizeof(n_dims)); + fin.read_raw(&name_len, 
sizeof(name_len)); + fin.read_raw(&ftype, sizeof(ftype)); + + if (n_dims != 1 && n_dims != 2) { + LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims); + return 1; + } + + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read_raw(&ne[i], sizeof(ne[i])); + } + + std::string name; + { + GGML_ASSERT(name_len <= 1024); + char buf[1024]; + fin.read_raw(buf, name_len); + name = std::string(buf, name_len); + } + + // check for lora suffix and get the type of tensor + const std::string lora_suffix = ".lora"; + size_t pos = name.rfind(lora_suffix); + if (pos == std::string::npos) { + LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str()); + return 1; + } + + std::string lora_type = name.substr(pos + lora_suffix.length()); + std::string base_name = name; + base_name.erase(pos); + // LLAMA_LOG_INFO("%s: %s => %s (lora type %s) \n", __func__, name.c_str(), base_name.c_str(), lora_type.c_str()); + + if (model_tensors.find(base_name) == model_tensors.end()) { + LLAMA_LOG_ERROR("%s: unknown tensor '%s' in lora adapter\n", __func__, name.data()); + return 1; + } + + // create ggml tensor + ggml_type wtype; + switch (ftype) { + case 0: wtype = GGML_TYPE_F32; break; + case 1: wtype = GGML_TYPE_F16; break; + default: + { + LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n", + __func__, ftype); + return false; + } + } + ggml_tensor * lora_tensor = ggml_new_tensor_2d(lora_ctx.get(), wtype, ne[0], ne[1]); + ggml_set_name(lora_tensor, name.c_str()); + + // load tensor data + size_t offset = fin.tell(); + size_t tensor_data_size = ggml_nbytes(lora_tensor); + offset = (offset + 31) & -32; + fin.seek(offset, SEEK_SET); + fin.read_raw(lora_tensor->data, tensor_data_size); + + lora_tensors[name] = lora_tensor; + + // check if we have both A and B tensors and apply + if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() && + lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { + + ggml_tensor * dest_t = model_tensors[base_name]; + + offload_func_t offload_func = ggml_offload_nop; + offload_func_t offload_func_force_inplace = ggml_offload_nop; + +#ifdef GGML_USE_CUBLAS + if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) { + if (dest_t->type != GGML_TYPE_F16) { + throw std::runtime_error(format( + "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models. 
dest_t->type: %d", __func__, dest_t->type)); + } + offload_func = ggml_cuda_assign_buffers; + offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace; + } +#endif // GGML_USE_CUBLAS + + ggml_tensor * base_t; + if (ml) { + struct gguf_context * ctx_gguf = ml->ctx_gguf; + + // load from base model + if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) { + LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); + return 1; + } + + base_t = ml->create_tensor(base_ctx.get(), base_name, { dest_t->ne[0], dest_t->ne[1] }, GGML_BACKEND_CPU); + ml->load_data_for(base_t); + } else { + base_t = dest_t; + } + + if (ggml_is_quantized(base_t->type)) { + if (!warned) { + LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, " + "use a f16 or f32 base model with --lora-base\n", __func__); + warned = true; + } + } + + ggml_tensor * loraA = lora_tensors[base_name + ".loraA"]; + GGML_ASSERT(loraA->type == GGML_TYPE_F32); + ggml_set_name(loraA, "loraA"); + + ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; + GGML_ASSERT(loraB->type == GGML_TYPE_F32); + ggml_set_name(loraB, "loraB"); + + if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { + LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" + " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); + return 1; + } + + // w = w + BA*s + ggml_tensor * BA = ggml_mul_mat(lora_ctx.get(), loraA, loraB); + offload_func(BA); + ggml_set_name(BA, "BA"); + + if (scaling != 1.0f) { + ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx.get(), scaling); + ggml_set_name(scale_tensor, "scale_tensor"); + + BA = ggml_scale_inplace(lora_ctx.get(), BA, scale_tensor); + offload_func(BA); + ggml_set_name(BA, "BA_scaled"); + } + + ggml_tensor * r; + if (base_t == dest_t) { + r = ggml_add_inplace(lora_ctx.get(), dest_t, BA); + offload_func_force_inplace(r); + ggml_set_name(r, "r_add_inplace"); + } + else { + r = ggml_add(lora_ctx.get(), base_t, BA); + offload_func(r); + ggml_set_name(r, "r_add"); + + r = ggml_cpy(lora_ctx.get(), r, dest_t); + offload_func(r); + ggml_set_name(r, "r_cpy"); + } + + struct ggml_cgraph * gf = ggml_new_graph(lora_ctx.get()); + ggml_build_forward_expand(gf, r); + + ggml_graph_compute_helper(work_buffer, gf, n_threads); + + // the tensors in the adapter must be sorted such that loraA and loraB of the same tensor are next to each other + GGML_ASSERT(lora_tensors.size() == 2); + + // we won't need these tensors again, reset the context to save memory + lora_ctx.reset(ggml_init(params)); + lora_tensors.clear(); + + n_tensors++; + if (n_tensors % 4 == 0) { + LLAMA_LOG_INFO("."); + } + } + } + + const int64_t t_lora_us = ggml_time_us() - t_start_lora_us; + LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0); + + return 0; +} + +// +// interface implementation +// +struct llama_model_params llama_model_default_params() { + struct llama_model_params result = { + /*.n_gpu_layers =*/ 0, + /*.main_gpu =*/ 0, + /*.tensor_split =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.kv_overrides =*/ nullptr, + /*.vocab_only =*/ false, + /*.use_mmap =*/ true, + /*.use_mlock =*/ false, + }; + +#ifdef GGML_USE_METAL + result.n_gpu_layers = 1; +#endif + + return result; +} + +struct llama_context_params llama_context_default_params() { + struct llama_context_params result = { + /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_ctx =*/ 512, + 
/*.n_batch =*/ 512, + /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default + /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED, + /*.rope_freq_base =*/ 0.0f, + /*.rope_freq_scale =*/ 0.0f, + /*.yarn_ext_factor =*/ -1.0f, + /*.yarn_attn_factor =*/ 1.0f, + /*.yarn_beta_fast =*/ 32.0f, + /*.yarn_beta_slow =*/ 1.0f, + /*.yarn_orig_ctx =*/ 0, + /*.type_k =*/ GGML_TYPE_F16, + /*.type_v =*/ GGML_TYPE_F16, + /*.mul_mat_q =*/ true, + /*.logits_all =*/ false, + /*.embedding =*/ false, + /*.offload_kqv =*/ true, + }; + + return result; +} + +struct llama_model_quantize_params llama_model_quantize_default_params() { + struct llama_model_quantize_params result = { + /*.nthread =*/ 0, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.allow_requantize =*/ false, + /*.quantize_output_tensor =*/ true, + /*.only_copy =*/ false, + /*.pure =*/ false, + }; + + return result; +} + +int llama_max_devices(void) { + return LLAMA_MAX_DEVICES; +} + +bool llama_mmap_supported(void) { + return llama_mmap::SUPPORTED; +} + +bool llama_mlock_supported(void) { + return llama_mlock::SUPPORTED; +} + +void llama_backend_init(bool numa) { + ggml_time_init(); + + // needed to initialize f16 tables + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + if (numa) { + ggml_numa_init(); + } + +#ifdef GGML_USE_MPI + ggml_mpi_backend_init(); +#endif +} + +void llama_backend_free(void) { +#ifdef GGML_USE_MPI + ggml_mpi_backend_free(); +#endif +} + +int64_t llama_time_us(void) { + return ggml_time_us(); +} + +struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params) { + ggml_time_init(); + + llama_model * model = new llama_model; + + unsigned cur_percentage = 0; + if (params.progress_callback == NULL) { + params.progress_callback_user_data = &cur_percentage; + params.progress_callback = [](float progress, void * ctx) { + unsigned * cur_percentage_p = (unsigned *) ctx; + unsigned percentage = (unsigned) (100 * progress); + while (percentage > *cur_percentage_p) { + *cur_percentage_p = percentage; + LLAMA_LOG_INFO("."); + if (percentage >= 100) { + LLAMA_LOG_INFO("\n"); + } + } + }; + } + + if (!llama_model_load(path_model, *model, params)) { + LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); + delete model; + return nullptr; + } + + return model; +} + +void llama_free_model(struct llama_model * model) { + delete model; +} + +struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params) { + + if (!model) { + return nullptr; + } + + llama_context * ctx = new llama_context(*model); + + const auto & hparams = model->hparams; + auto & cparams = ctx->cparams; + + cparams.n_batch = params.n_batch; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.yarn_ext_factor = params.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow; + cparams.mul_mat_q = params.mul_mat_q; + cparams.offload_kqv = params.offload_kqv; + + cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; + cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? 
hparams.rope_freq_scale_train : params.rope_freq_scale; + + cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : + hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : + hparams.n_ctx_train; + + auto rope_scaling_type = params.rope_scaling_type; + if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) { + rope_scaling_type = hparams.rope_scaling_type_train; + } + + if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) { + cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none + } + + if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set' + cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f; + } + + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + + ctx->rng = std::mt19937(params.seed); + ctx->logits_all = params.logits_all; + + const ggml_type type_k = params.type_k; + const ggml_type type_v = params.type_v; + + GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_k) == 0); + GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_v) == 0); + + // reserve memory for context buffers + if (!hparams.vocab_only) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { + LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); + llama_free(ctx); + return nullptr; + } + + { + size_t memory_size_k = 0; + size_t memory_size_v = 0; + + for (auto & k : ctx->kv_self.k_l) { + memory_size_k += ggml_nbytes(k); + } + + for (auto & v : ctx->kv_self.v_l) { + memory_size_v += ggml_nbytes(v); + } + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } + + // resized during inference + if (params.logits_all) { + ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab); + } else { + ctx->logits.reserve(hparams.n_vocab); + } + + if (params.embedding){ + ctx->embedding.resize(hparams.n_embd); + } + + { + static const size_t tensor_alignment = 32; + // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data + ctx->buf_compute.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead()); + + // create measure allocator + ctx->alloc = ggml_allocr_new_measure(tensor_alignment); + + // build worst-case graph + int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); + int n_past = cparams.n_ctx - n_tokens; + llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); + +#ifdef GGML_USE_METAL + if (model->n_gpu_layers > 0) { + ctx->ctx_metal = ggml_metal_init(1); + if (!ctx->ctx_metal) { + LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); + llama_free(ctx); + return NULL; + } + //ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); + //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), 
ggml_metal_if_optimized(ctx->ctx_metal)); + } +#endif + // measure memory requirements for the graph + size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; + + LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MiB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); + + // recreate allocator with exact memory requirements + ggml_allocr_free(ctx->alloc); + + ctx->buf_alloc.resize(alloc_size); + ctx->alloc = ggml_allocr_new(ctx->buf_alloc.data, ctx->buf_alloc.size, tensor_alignment); +#ifdef GGML_USE_METAL + if (ctx->ctx_metal) { + //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + } +#endif +#ifdef GGML_USE_CUBLAS + ggml_cuda_set_scratch_size(alloc_size); + LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MiB\n", __func__, alloc_size / 1024.0 / 1024.0); + + // calculate total VRAM usage + auto add_tensor = [](const ggml_tensor * t, size_t & size) { + if (t->backend == GGML_BACKEND_GPU || t->backend == GGML_BACKEND_GPU_SPLIT) { + size += ggml_nbytes(t); + } + }; + size_t model_vram_size = 0; + for (const auto & kv : model->tensors_by_name) { + add_tensor(kv.second, model_vram_size); + } + + size_t kv_vram_size = 0; + for (auto & k : ctx->kv_self.k_l) { + add_tensor(k, kv_vram_size); + } + for (auto & v : ctx->kv_self.v_l) { + add_tensor(v, kv_vram_size); + } + + size_t ctx_vram_size = alloc_size + kv_vram_size; + size_t total_vram_size = model_vram_size + ctx_vram_size; + + LLAMA_LOG_INFO("%s: total VRAM used: %.2f MiB (model: %.2f MiB, context: %.2f MiB)\n", __func__, + total_vram_size / 1024.0 / 1024.0, + model_vram_size / 1024.0 / 1024.0, + ctx_vram_size / 1024.0 / 1024.0); +#endif + } + +#ifdef GGML_USE_METAL + if (model->n_gpu_layers > 0) { + // this allocates all Metal resources and memory buffers + + void * data_ptr = NULL; + size_t data_size = 0; + + if (ctx->model.mapping) { + data_ptr = ctx->model.mapping->addr; + data_size = ctx->model.mapping->size; + } else { + data_ptr = ggml_get_mem_buffer(ctx->model.ctx); + data_size = ggml_get_mem_size (ctx->model.ctx); + } + + const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); + + LLAMA_LOG_INFO("%s: max tensor size = %8.2f MiB\n", __func__, max_size/1024.0/1024.0); + +#define LLAMA_METAL_CHECK_BUF(result) \ + if (!(result)) { \ + LLAMA_LOG_ERROR("%s: failed to add buffer\n", __func__); \ + llama_free(ctx); \ + return NULL; \ + } + + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0)); +#undef LLAMA_METAL_CHECK_BUF + } +#endif + } + +#ifdef GGML_USE_MPI + ctx->ctx_mpi = ggml_mpi_init(); + + if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + // TODO: needs fix after #3228 + GGML_ASSERT(false && "not implemented"); + //const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); + //while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + llama_backend_free(); + exit(1); + } +#endif + + return ctx; +} + +void llama_free(struct llama_context * ctx) { + delete ctx; +} + +const llama_model * llama_get_model(const struct llama_context * ctx) { + return &ctx->model; +} + +int llama_n_ctx(const struct llama_context * ctx) { + return 
ctx->cparams.n_ctx; +} + +enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { + return model->vocab.type; +} + +int llama_n_vocab(const struct llama_model * model) { + return model->vocab.id_to_token.size(); +} + +int llama_n_ctx_train(const struct llama_model * model) { + return model->hparams.n_ctx_train; +} + +int llama_n_embd(const struct llama_model * model) { + return model->hparams.n_embd; +} + +float llama_rope_freq_scale_train(const struct llama_model * model) { + return model->hparams.rope_freq_scale_train; +} + +int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) { + const auto & it = model->gguf_kv.find(key); + if (it == model->gguf_kv.end()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + +int llama_model_meta_count(const struct llama_model * model) { + return (int)model->gguf_kv.size(); +} + +int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)model->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = model->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->first.c_str()); +} + +int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)model->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = model->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + +int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { + return snprintf(buf, buf_size, "%s %s %s", + llama_model_arch_name(model->arch).c_str(), + llama_model_type_name(model->type), + llama_model_ftype_name(model->ftype).c_str()); +} + +uint64_t llama_model_size(const struct llama_model * model) { + uint64_t size = 0; + for (const auto & it : model->tensors_by_name) { + size += ggml_nbytes(it.second); + } + return size; +} + +uint64_t llama_model_n_params(const struct llama_model * model) { + uint64_t nparams = 0; + for (const auto & it : model->tensors_by_name) { + nparams += ggml_nelements(it.second); + } + return nparams; +} + +struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) { + return ggml_get_tensor(model->ctx, name); +} + +int llama_model_quantize( + const char * fname_inp, + const char * fname_out, + const llama_model_quantize_params * params) { + try { + llama_model_quantize_internal(fname_inp, fname_out, params); + return 0; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what()); + return 1; + } +} + +int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, float scale, const char * path_base_model, int n_threads) { + try { + return llama_apply_lora_from_file_internal(ctx->model, path_lora, scale, path_base_model, n_threads); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); + return 1; + } +} + +int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int n_threads) { + try { + return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: 
failed to apply lora adapter: %s\n", __func__, err.what()); + return 1; + } +} + +struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) { + struct llama_kv_cache_view result = { + /*.n_cells = */ 0, + /*.n_max_seq = */ n_max_seq, + /*.token_count = */ 0, + /*.used_cells = */ llama_get_kv_cache_used_cells(ctx), + /*.max_contiguous = */ 0, + /*.max_contiguous_idx = */ -1, + /*.cells = */ nullptr, + /*.cells_sequences = */ nullptr, + }; + return result; +} + +void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { + if (view->cells != nullptr) { + free(view->cells); + view->cells = nullptr; + } + if (view->cells_sequences != nullptr) { + free(view->cells_sequences); + view->cells_sequences = nullptr; + } +} + +void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) { + if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) { + view->n_cells = int32_t(ctx->kv_self.size); + void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); + GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); + view->cells = (struct llama_kv_cache_view_cell *)p; + p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells); + GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences"); + view->cells_sequences = (llama_seq_id *)p; + } + + const std::vector & kv_cells = ctx->kv_self.cells; + llama_kv_cache_view_cell * c_curr = view->cells; + llama_seq_id * cs_curr = view->cells_sequences; + int32_t used_cells = 0; + int32_t token_count = 0; + int32_t curr_contig_idx = -1; + uint32_t max_contig = 0; + int32_t max_contig_idx = -1; + + for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) { + const size_t curr_size = kv_cells[i].seq_id.size(); + token_count += curr_size; + c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; + + if (curr_size > 0) { + if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) { + max_contig = i - curr_contig_idx; + max_contig_idx = curr_contig_idx; + } + curr_contig_idx = -1; + } else if (curr_contig_idx < 0) { + curr_contig_idx = i; + } + + int seq_idx = 0; + for (const llama_seq_id it : kv_cells[i].seq_id) { + if (seq_idx >= view->n_max_seq) { + break; + } + cs_curr[seq_idx] = it; + seq_idx++; + } + if (seq_idx != 0) { + used_cells++; + } + for (; seq_idx < view->n_max_seq; seq_idx++) { + cs_curr[seq_idx] = -1; + } + } + if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) { + max_contig_idx = curr_contig_idx; + max_contig = kv_cells.size() - curr_contig_idx; + } + view->max_contiguous = max_contig; + view->max_contiguous_idx = max_contig_idx; + view->token_count = token_count; + view->used_cells = used_cells; + if (uint32_t(used_cells) != ctx->kv_self.used) { + LLAMA_LOG_ERROR("%s: used cells mismatch. 
kv_cache says %d but we calculated %d\n", + __func__, ctx->kv_self.used, used_cells); + } +} + +int llama_get_kv_cache_token_count(const struct llama_context * ctx) { + int result = 0; + + for (uint32_t i = 0; i < ctx->kv_self.size; i++) { + result += ctx->kv_self.cells[i].seq_id.size(); + } + + return result; +} + +int llama_get_kv_cache_used_cells(const struct llama_context * ctx) { + return ctx->kv_self.used; +} + +void llama_kv_cache_clear(struct llama_context * ctx) { + llama_kv_cache_clear(ctx->kv_self); +} + +void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); +} + +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { + llama_kv_cache_seq_keep(ctx->kv_self, seq_id); +} + +void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); +} + +// Returns the *maximum* size of the state +size_t llama_get_state_size(const struct llama_context * ctx) { + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. + // for reference, std::mt19937(1337) serializes to 6701 bytes. + const size_t s_rng_size = sizeof(size_t); + const size_t s_rng = LLAMA_MAX_RNG_STATE; + const size_t s_logits_capacity = sizeof(size_t); + const size_t s_logits_size = sizeof(size_t); + const size_t s_logits = ctx->logits.capacity() * sizeof(float); + const size_t s_embedding_size = sizeof(size_t); + const size_t s_embedding = ctx->embedding.size() * sizeof(float); + const size_t s_kv_size = sizeof(size_t); + const size_t s_kv_ntok = sizeof(int); + const size_t s_kv = ctx->kv_self.buf.size; + + const size_t s_total = ( + + s_rng_size + + s_rng + + s_logits_capacity + + s_logits_size + + s_logits + + s_embedding_size + + s_embedding + + s_kv_size + + s_kv_ntok + + s_kv + ); + + return s_total; +} + +// llama_context_data +struct llama_data_context { + virtual void write(const void * src, size_t size) = 0; + virtual size_t get_size_written() = 0; + virtual ~llama_data_context() = default; +}; + +struct llama_data_buffer_context : llama_data_context { + uint8_t * ptr; + size_t size_written = 0; + + llama_data_buffer_context(uint8_t * p) : ptr(p) {} + + void write(const void * src, size_t size) override { + memcpy(ptr, src, size); + ptr += size; + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_file_context : llama_data_context { + llama_file * file; + size_t size_written = 0; + + llama_data_file_context(llama_file * f) : file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +/** copy state data into either a buffer or file depending on the passed in context + * + * file context: + * llama_file file("/path", "wb"); + * llama_data_file_context data_ctx(&file); + * llama_copy_state_data(ctx, &data_ctx); + * + * buffer context: + * std::vector buf(max_size, 0); + * llama_data_buffer_context data_ctx(&buf.data()); + * 
llama_copy_state_data(ctx, &data_ctx); + * +*/ +static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { + // copy rng + { + std::stringstream rng_ss; + rng_ss << ctx->rng; + + const size_t rng_size = rng_ss.str().size(); + char rng_buf[LLAMA_MAX_RNG_STATE]; + + memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + + data_ctx->write(&rng_size, sizeof(rng_size)); + data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE); + } + + // copy logits + { + const size_t logits_cap = ctx->logits.capacity(); + const size_t logits_size = ctx->logits.size(); + + data_ctx->write(&logits_cap, sizeof(logits_cap)); + data_ctx->write(&logits_size, sizeof(logits_size)); + + if (logits_size) { + data_ctx->write(ctx->logits.data(), logits_size * sizeof(float)); + } + + // If there is a gap between the size and the capacity, write padding + size_t padding_size = (logits_cap - logits_size) * sizeof(float); + if (padding_size > 0) { + std::vector padding(padding_size, 0); // Create a buffer filled with zeros + data_ctx->write(padding.data(), padding_size); + } + } + + // copy embeddings + { + const size_t embedding_size = ctx->embedding.size(); + + data_ctx->write(&embedding_size, sizeof(embedding_size)); + + if (embedding_size) { + data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float)); + } + } + + // copy kv cache + { + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; + const auto & cparams = ctx->cparams; + + const auto n_layer = hparams.n_layer; + const auto n_embd = hparams.n_embd_gqa(); + const auto n_ctx = cparams.n_ctx; + + const size_t kv_buf_size = kv_self.buf.size; + const uint32_t kv_head = kv_self.head; + const uint32_t kv_size = kv_self.size; + const uint32_t kv_used = kv_self.used; + + data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); + data_ctx->write(&kv_head, sizeof(kv_head)); + data_ctx->write(&kv_size, sizeof(kv_size)); + data_ctx->write(&kv_used, sizeof(kv_used)); + + if (kv_buf_size) { + const size_t elt_size = ggml_element_size(kv_self.k_l[0]); + + ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true }); + ggml_cgraph * gf = ggml_new_graph(cpy_ctx); + + std::vector> kout2d_data(n_layer); + std::vector> vout2d_data(n_layer); + + for (int il = 0; il < (int) n_layer; ++il) { + ggml_tensor * kout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head); + kout2d_data[il].resize(ggml_nbytes(kout2d)); + kout2d->data = kout2d_data[il].data(); + + ggml_tensor * vout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd); + vout2d_data[il].resize(ggml_nbytes(vout2d)); + vout2d->data = vout2d_data[il].data(); + + ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il], + n_embd, kv_head, + elt_size*n_embd, 0); + + ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il], + kv_head, n_embd, + elt_size*n_ctx, 0); + + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k2d, kout2d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v2d, vout2d)); + } + + ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); + + ggml_free(cpy_ctx); + + // our data is now in the kout2d_data and vout2d_data buffers + // write them to file + for (uint32_t il = 0; il < n_layer; ++il) { + data_ctx->write(kout2d_data[il].data(), kout2d_data[il].size()); + data_ctx->write(vout2d_data[il].data(), vout2d_data[il].size()); + } + } + + for (uint32_t i = 0; i < kv_size; ++i) 
{ + const auto & cell = kv_self.cells[i]; + + const llama_pos pos = cell.pos; + const size_t seq_id_size = cell.seq_id.size(); + + data_ctx->write(&pos, sizeof(pos)); + data_ctx->write(&seq_id_size, sizeof(seq_id_size)); + + for (auto seq_id : cell.seq_id) { + data_ctx->write(&seq_id, sizeof(seq_id)); + } + } + } +} + +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + llama_data_buffer_context data_ctx(dst); + llama_copy_state_data_internal(ctx, &data_ctx); + + return data_ctx.get_size_written(); +} + +// Sets the state reading from the specified source address +size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { + uint8_t * inp = src; + + // set rng + { + size_t rng_size; + char rng_buf[LLAMA_MAX_RNG_STATE]; + + memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); + memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE; + + std::stringstream rng_ss; + rng_ss.str(std::string(&rng_buf[0], rng_size)); + rng_ss >> ctx->rng; + + GGML_ASSERT(!rng_ss.fail()); + } + + // set logits + { + size_t logits_cap; + size_t logits_size; + + memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap); + memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); + + GGML_ASSERT(ctx->logits.capacity() == logits_cap); + + if (logits_size) { + ctx->logits.resize(logits_size); + memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); + } + + inp += logits_cap * sizeof(float); + } + + // set embeddings + { + size_t embedding_size; + + memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); + + GGML_ASSERT(ctx->embedding.capacity() == embedding_size); + + if (embedding_size) { + memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); + inp += embedding_size * sizeof(float); + } + } + + // set kv cache + { + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; + const auto & cparams = ctx->cparams; + + const int n_layer = hparams.n_layer; + const int n_embd = hparams.n_embd_gqa(); + const int n_ctx = cparams.n_ctx; + + size_t kv_buf_size; + uint32_t kv_head; + uint32_t kv_size; + uint32_t kv_used; + + memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); + memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); + memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); + memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); + + if (kv_buf_size) { + GGML_ASSERT(kv_self.buf.size == kv_buf_size); + + const size_t elt_size = ggml_element_size(kv_self.k_l[0]); + + ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true }); + ggml_cgraph * gf = ggml_new_graph(cpy_ctx); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * kin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head); + kin2d->data = (void *) inp; + inp += ggml_nbytes(kin2d); + + ggml_tensor * vin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd); + vin2d->data = (void *) inp; + inp += ggml_nbytes(vin2d); + + ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il], + n_embd, kv_head, + elt_size*n_embd, 0); + + ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il], + kv_head, n_embd, + elt_size*n_ctx, 0); + + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin2d, k2d)); + ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin2d, v2d)); + } + + ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1); + + 
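// Note: the llama_copy_state_data()/llama_set_state_data() pair defined in this file is meant to be
// driven together from the public API. A minimal, illustrative round trip (assuming `ctx` is a valid
// llama_context and `state` is a caller-owned buffer; not part of the original sources) could look like:
//
//     std::vector<uint8_t> state(llama_get_state_size(ctx));           // upper bound on the serialized size
//     const size_t written = llama_copy_state_data(ctx, state.data()); // snapshot rng / logits / embeddings / kv
//     // ... later, on a context created from the same model and parameters ...
//     const size_t read = llama_set_state_data(ctx, state.data());     // restore the snapshot
//     // for an unmodified buffer, `read` is expected to match `written`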
ggml_free(cpy_ctx); + } + + ctx->kv_self.head = kv_head; + ctx->kv_self.size = kv_size; + ctx->kv_self.used = kv_used; + + ctx->kv_self.cells.resize(kv_size); + + for (uint32_t i = 0; i < kv_size; ++i) { + llama_pos pos; + size_t seq_id_size; + + memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos); + memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size); + + ctx->kv_self.cells[i].pos = pos; + + llama_seq_id seq_id; + + for (size_t j = 0; j < seq_id_size; ++j) { + memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id); + ctx->kv_self.cells[i].seq_id.insert(seq_id); + } + } + } + + const size_t nread = inp - src; + const size_t max_size = llama_get_state_size(ctx); + + GGML_ASSERT(nread <= max_size); + + return nread; +} + +static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(path_session, "rb"); + + // sanity checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { + LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); + return false; + } + + llama_hparams session_hparams; + file.read_raw(&session_hparams, sizeof(llama_hparams)); + + if (session_hparams != ctx->model.hparams) { + LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); + return false; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return false; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t n_state_size_cur = file.size - file.tell(); + const size_t n_state_size_max = llama_get_state_size(ctx); + + if (n_state_size_cur > n_state_size_max) { + LLAMA_LOG_ERROR("%s : the state size in session file is too big! 
max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur); + return false; + } + + std::vector state_data(n_state_size_max); + file.read_raw(state_data.data(), n_state_size_cur); + + llama_set_state_data(ctx, state_data.data()); + } + + return true; +} + +bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error loading session file: %s\n", err.what()); + return false; + } +} + +bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + llama_file file(path_session, "wb"); + + file.write_u32(LLAMA_SESSION_MAGIC); + file.write_u32(LLAMA_SESSION_VERSION); + + file.write_raw(&ctx->model.hparams, sizeof(llama_hparams)); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_file_context data_ctx(&file); + llama_copy_state_data_internal(ctx, &data_ctx); + + return true; +} + +int llama_eval( + struct llama_context * ctx, + llama_token * tokens, + int32_t n_tokens, + int n_past) { + llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); + + const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; +} + +int llama_eval_embd( + struct llama_context * ctx, + float * embd, + int32_t n_tokens, + int n_past) { + llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); + + llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; + + const int ret = llama_decode_internal(*ctx, batch); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; +} + +void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { + ctx->cparams.n_threads = n_threads; + ctx->cparams.n_threads_batch = n_threads_batch; +} + +struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens, + llama_pos pos_0, + llama_seq_id seq_id) { + return { + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*all_pos_0 =*/ pos_0, + /*all_pos_1 =*/ 1, + /*all_seq_id =*/ seq_id, + }; +} + +struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) { + llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; + + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); + } + + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +void llama_batch_free(struct llama_batch batch) { + if (batch.token) free(batch.token); + if 
(batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; i < batch.n_tokens; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +int llama_decode( + struct llama_context * ctx, + struct llama_batch batch) { + const int ret = llama_decode_internal(*ctx, batch); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; +} + +float * llama_get_logits(struct llama_context * ctx) { + return ctx->logits.data(); +} + +float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + assert(ctx->logits_valid.at(i)); + return ctx->logits.data() + i*ctx->model.hparams.n_vocab; +} + +float * llama_get_embeddings(struct llama_context * ctx) { + return ctx->embedding.data(); +} + +const char * llama_token_get_text(const struct llama_model * model, llama_token token) { + return model->vocab.id_to_token[token].text.c_str(); +} + +float llama_token_get_score(const struct llama_model * model, llama_token token) { + return model->vocab.id_to_token[token].score; +} + +llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) { + return model->vocab.id_to_token[token].type; +} + +llama_token llama_token_bos(const struct llama_model * model) { + return model->vocab.special_bos_id; +} + +llama_token llama_token_eos(const struct llama_model * model) { + return model->vocab.special_eos_id; +} + +llama_token llama_token_nl(const struct llama_model * model) { + return model->vocab.linefeed_id; +} + +int llama_add_bos_token(const struct llama_model * model) { + return model->vocab.special_add_bos; +} + +int llama_add_eos_token(const struct llama_model * model) { + return model->vocab.special_add_eos; +} + +llama_token llama_token_prefix(const struct llama_model * model) { + return model->vocab.special_prefix_id; +} + +llama_token llama_token_middle(const struct llama_model * model) { + return model->vocab.special_middle_id; +} + +llama_token llama_token_suffix(const struct llama_model * model) { + return model->vocab.special_suffix_id; +} + +llama_token llama_token_eot(const struct llama_model * model) { + return model->vocab.special_eot_id; +} + +int llama_tokenize( + const struct llama_model * model, + const char * text, + int text_len, + llama_token * tokens, + int n_max_tokens, + bool add_bos, + bool special) { + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special); + + if (n_max_tokens < (int) res.size()) { + // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); + return -((int) res.size()); + } + + for (size_t i = 0; i < res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +static std::string llama_decode_text(const std::string & text) { + std::string decoded_text; + auto unicode_sequences = codepoints_from_utf8(text); + for (auto& unicode_sequence : unicode_sequences) { + decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence)); + } + + return decoded_text; +} + +// does not write null-terminator to buf +int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) { + if (0 <= token && token < llama_n_vocab(model)) { + switch (llama_vocab_get_type(model->vocab)) { + case LLAMA_VOCAB_TYPE_SPM: { + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; + llama_unescape_whitespace(result); + if (length < 
(int) result.length()) { + return -result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT + if (length < 3) { + return -3; + } + memcpy(buf, "\xe2\x96\x85", 3); + return 3; + } else if (llama_is_control_token(model->vocab, token)) { + ; + } else if (llama_is_byte_token(model->vocab, token)) { + if (length < 1) { + return -1; + } + buf[0] = llama_token_to_byte(model->vocab, token); + return 1; + } else { + // TODO: for now we accept all unsupported token types, + // suppressing them like CONTROL tokens. + // GGML_ASSERT(false); + } + break; + } + case LLAMA_VOCAB_TYPE_BPE: { + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; + result = llama_decode_text(result); + if (length < (int) result.length()) { + return -result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if (llama_is_control_token(model->vocab, token)) { + ; + } else { + // TODO: for now we accept all unsupported token types, + // suppressing them like CONTROL tokens. + // GGML_ASSERT(false); + } + break; + } + default: + GGML_ASSERT(false); + } + } + return 0; +} + +struct llama_timings llama_get_timings(struct llama_context * ctx) { + struct llama_timings result = { + /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, + /*.t_end_ms =*/ 1.00 * ggml_time_ms(), + /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, + /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us, + /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, + /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, + + /*.n_sample =*/ std::max(1, ctx->n_sample), + /*.n_p_eval =*/ std::max(1, ctx->n_p_eval), + /*.n_eval =*/ std::max(1, ctx->n_eval), + }; + + return result; +} + +void llama_print_timings(struct llama_context * ctx) { + const llama_timings timings = llama_get_timings(ctx); + + LLAMA_LOG_INFO("\n"); + LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); + LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); + LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); + LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval); + LLAMA_LOG_INFO("%s: total time = %10.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms)); +} + +void llama_reset_timings(struct llama_context * ctx) { + ctx->t_start_us = ggml_time_us(); + ctx->t_sample_us = ctx->n_sample = 0; + ctx->t_eval_us = ctx->n_eval = 0; + ctx->t_p_eval_us = ctx->n_p_eval = 0; +} + +const char * llama_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | "; + s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s 
+= "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + + return s.c_str(); +} + +void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { + fprintf(stream, "\n"); + fprintf(stream, "###########\n"); + fprintf(stream, "# Timings #\n"); + fprintf(stream, "###########\n"); + fprintf(stream, "\n"); + + fprintf(stream, "mst_eval: %.2f # ms / token during generation\n", + 1.0e-3 * ctx->t_eval_us / ctx->n_eval); + fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", + 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); + fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", + 1.0e-3 * ctx->t_sample_us / ctx->n_sample); + fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); + fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); + fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample); + fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); + fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); + fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); + fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us); + fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", + 1.0e6 * ctx->n_eval / ctx->t_eval_us); + fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", + 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); + fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", + 1.0e6 * ctx->n_sample / ctx->t_sample_us); +} + +// For internal test use +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +) { + return ctx->model.tensors_by_name; +} + +void llama_log_set(ggml_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? log_callback : llama_log_callback_default; + g_state.log_callback_user_data = user_data; +#ifdef GGML_USE_METAL + ggml_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data); +#endif +} + +static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) { + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args_copy); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args_copy); +} + +static void llama_log_internal(ggml_log_level level, const char * format, ...) 
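// llama_log_internal() below formats the message and forwards it to the callback installed via
// llama_log_set() above. An application can redirect the output with a callback shaped like
// llama_log_callback_default(); a hedged sketch follows (`my_log_cb` is a hypothetical name):
//
//     static void my_log_cb(ggml_log_level level, const char * text, void * user_data) {
//         (void) level; (void) user_data;
//         fputs(text, stdout); // e.g. send llama/ggml log text to stdout instead of stderr
//     }
//     // during start-up: llama_log_set(my_log_cb, NULL);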
{ + va_list args; + va_start(args, format); + llama_log_internal_v(level, format, args); + va_end(args); +} + +static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} diff --git a/applications/src/llama/llama.h b/applications/src/llama/llama.h new file mode 100644 index 000000000..15ab4f80e --- /dev/null +++ b/applications/src/llama/llama.h @@ -0,0 +1,871 @@ +#ifndef LLAMA_H +#define LLAMA_H + +#include "ggml.h" +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES +#else +#define LLAMA_MAX_DEVICES 1 +#endif // GGML_USE_CUBLAS +#include +#include +#include +#include + +#ifdef LLAMA_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef LLAMA_BUILD +# define LLAMA_API __declspec(dllexport) +# else +# define LLAMA_API __declspec(dllimport) +# endif +# else +# define LLAMA_API __attribute__ ((visibility ("default"))) +# endif +#else +# define LLAMA_API +#endif + +#ifdef __GNUC__ +# define DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define DEPRECATED(func, hint) func +#endif + +#define LLAMA_DEFAULT_SEED 0xFFFFFFFF + +#define LLAMA_MAX_RNG_STATE (64*1024) + +#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' +#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' + +#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN +#define LLAMA_SESSION_VERSION 3 + +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) +// Defined when llama.cpp is compiled with support for offloading model layers to GPU. +#define LLAMA_SUPPORTS_GPU_OFFLOAD +#endif + +#ifdef __cplusplus +extern "C" { +#endif + + // + // C interface + // + // TODO: show sample usage + // + + struct llama_model; + struct llama_context; + + typedef int32_t llama_pos; + typedef int32_t llama_token; + typedef int32_t llama_seq_id; + + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece + LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding + }; + + enum llama_token_type { + LLAMA_TOKEN_TYPE_UNDEFINED = 0, + LLAMA_TOKEN_TYPE_NORMAL = 1, + LLAMA_TOKEN_TYPE_UNKNOWN = 2, + LLAMA_TOKEN_TYPE_CONTROL = 3, + LLAMA_TOKEN_TYPE_USER_DEFINED = 4, + LLAMA_TOKEN_TYPE_UNUSED = 5, + LLAMA_TOKEN_TYPE_BYTE = 6, + }; + + // model file types + enum llama_ftype { + LLAMA_FTYPE_ALL_F32 = 0, + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors + + LLAMA_FTYPE_GUESSED = 1024, // not specified in 
the model file + }; + + enum llama_rope_scaling_type { + LLAMA_ROPE_SCALING_UNSPECIFIED = -1, + LLAMA_ROPE_SCALING_NONE = 0, + LLAMA_ROPE_SCALING_LINEAR = 1, + LLAMA_ROPE_SCALING_YARN = 2, + LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN, + }; + + typedef struct llama_token_data { + llama_token id; // token id + float logit; // log-odds of the token + float p; // probability of the token + } llama_token_data; + + typedef struct llama_token_data_array { + llama_token_data * data; + size_t size; + bool sorted; + } llama_token_data_array; + + typedef void (*llama_progress_callback)(float progress, void *ctx); + + // Input data for llama_decode + // A llama_batch object can contain input about one or many sequences + // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens + // + // - token : the token ids of the input (used when embd is NULL) + // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + // - pos : the positions of the respective token in the sequence + // - seq_id : the sequence to which the respective token belongs + // - logits : if zero, the logits for the respective token will not be output + // + typedef struct llama_batch { + int32_t n_tokens; + + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; + + // NOTE: helpers for smooth API transition - can be deprecated in the future + // for future-proof code, use the above fields instead and ignore everything below + // + // pos[i] = all_pos_0 + i*all_pos_1 + // + llama_pos all_pos_0; // used if pos == NULL + llama_pos all_pos_1; // used if pos == NULL + llama_seq_id all_seq_id; // used if seq_id == NULL + } llama_batch; + + enum llama_model_kv_override_type { + LLAMA_KV_OVERRIDE_INT, + LLAMA_KV_OVERRIDE_FLOAT, + LLAMA_KV_OVERRIDE_BOOL, + }; + + struct llama_model_kv_override { + char key[128]; + enum llama_model_kv_override_type tag; + union { + int64_t int_value; + double float_value; + bool bool_value; + }; + }; + + struct llama_model_params { + int32_t n_gpu_layers; // number of layers to store in VRAM + int32_t main_gpu; // the GPU that is used for scratch and small tensors + const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) + + // called with a progress value between 0 and 1, pass NULL to disable + llama_progress_callback progress_callback; + + // context pointer passed to the progress callback + void * progress_callback_user_data; + + // override key-value pairs of the model meta data + const struct llama_model_kv_override * kv_overrides; + + // Keep the booleans together to avoid misalignment during copy-by-value. 
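// A typical (non-normative) way to fill this struct is to start from llama_model_default_params()
// and override individual fields before loading; the path and layer count below are placeholders:
//
//     struct llama_model_params mparams = llama_model_default_params();
//     mparams.n_gpu_layers = 35;    // only has an effect when built with GPU offload support
//     mparams.use_mlock    = false; // keep the default mmap-based loading
//     struct llama_model * model = llama_load_model_from_file("/path/to/model.gguf", mparams);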
+ bool vocab_only; // only load the vocabulary, no weights + bool use_mmap; // use mmap if possible + bool use_mlock; // force system to keep model in RAM + }; + + struct llama_context_params { + uint32_t seed; // RNG seed, -1 for random + uint32_t n_ctx; // text context, 0 = from model + uint32_t n_batch; // prompt processing maximum batch size + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing + int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` + + // ref: https://github.com/ggerganov/llama.cpp/pull/2054 + float rope_freq_base; // RoPE base frequency, 0 = from model + float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model + float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model + float yarn_attn_factor; // YaRN magnitude scaling factor + float yarn_beta_fast; // YaRN low correction dim + float yarn_beta_slow; // YaRN high correction dim + uint32_t yarn_orig_ctx; // YaRN original context size + + enum ggml_type type_k; // data type for K cache + enum ggml_type type_v; // data type for V cache + + // Keep the booleans together to avoid misalignment during copy-by-value. + bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) + bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) + bool embedding; // embedding mode only + bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + }; + + // model quantization parameters + typedef struct llama_model_quantize_params { + int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // disable k-quant mixtures and quantize all tensors to the same type + } llama_model_quantize_params; + + // grammar types + struct llama_grammar; + + // grammar element type + enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID + } llama_grammar_element; + + // performance timing information + struct llama_timings { + double t_start_ms; + double t_end_ms; + double t_load_ms; + double t_sample_ms; + double t_p_eval_ms; + double t_eval_ms; + + int32_t n_sample; + int32_t n_p_eval; + int32_t n_eval; + }; + + // Helpers for getting default parameters + LLAMA_API struct llama_model_params llama_model_default_params(void); + LLAMA_API struct llama_context_params 
llama_context_default_params(void); + LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); + + // Initialize the llama + ggml backend + // If numa is true, use NUMA optimizations + // Call once at the start of the program + LLAMA_API void llama_backend_init(bool numa); + + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(void); + + LLAMA_API struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params); + + LLAMA_API void llama_free_model(struct llama_model * model); + + LLAMA_API struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params); + + // Frees all allocated memory + LLAMA_API void llama_free(struct llama_context * ctx); + + LLAMA_API int64_t llama_time_us(void); + + LLAMA_API int llama_max_devices (void); + LLAMA_API bool llama_mmap_supported (void); + LLAMA_API bool llama_mlock_supported(void); + + LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + + LLAMA_API int llama_n_ctx (const struct llama_context * ctx); + + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); + + LLAMA_API int llama_n_vocab (const struct llama_model * model); + LLAMA_API int llama_n_ctx_train(const struct llama_model * model); + LLAMA_API int llama_n_embd (const struct llama_model * model); + + // Get the model's RoPE frequency scaling factor + LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); + + // Functions to access the model's GGUF metadata scalar values + // - The functions return the length of the string on success, or -1 on failure + // - The output string is always null-terminated and cleared on failure + // - GGUF array values are not supported by these functions + + // Get metadata value as a string by key name + LLAMA_API int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size); + + // Get the number of metadata key/value pairs + LLAMA_API int llama_model_meta_count(const struct llama_model * model); + + // Get metadata key name by index + LLAMA_API int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size); + + // Get metadata value as a string by index + LLAMA_API int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size); + + // Get a string describing the model type + LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); + + // Returns the total size of all the tensors in the model in bytes + LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + + // Returns the total number of parameters in the model + LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); + + // Get a llama model tensor + LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); + + // Returns 0 on success + LLAMA_API int llama_model_quantize( + const char * fname_inp, + const char * fname_out, + const llama_model_quantize_params * params); + + // Apply a LoRA adapter to a loaded model + // path_base_model is the path to a higher quality model to use as a base for + // the layers modified by the adapter. Can be NULL to use the current loaded model. 
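
    // Illustrative sketch (not part of the upstream llama.cpp API): a minimal
    // model/context lifecycle using the loader functions declared above. The
    // function name, the model path argument and the 2048-token context size
    // are illustrative choices, not values required by the API.
    static int example_llama_lifecycle(const char * path_model) {
        llama_backend_init(false); // no NUMA optimizations

        struct llama_model_params mparams = llama_model_default_params();
        struct llama_model * model = llama_load_model_from_file(path_model, mparams);
        if (model == NULL) {
            llama_backend_free();
            return 1;
        }

        struct llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx = 2048; // request a fixed context size instead of the model default

        struct llama_context * ctx = llama_new_context_with_model(model, cparams);
        if (ctx == NULL) {
            llama_free_model(model);
            llama_backend_free();
            return 1;
        }

        // ... tokenize, decode and sample here ...

        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }
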
+ // The model needs to be reloaded before applying a new adapter, otherwise the adapter + // will be applied on top of the previous one + // Returns 0 on success + LLAMA_API DEPRECATED(int llama_apply_lora_from_file( + struct llama_context * ctx, + const char * path_lora, + float scale, + const char * path_base_model, + int n_threads), + "use llama_model_apply_lora_from_file instead"); + + LLAMA_API int llama_model_apply_lora_from_file( + const struct llama_model * model, + const char * path_lora, + float scale, + const char * path_base_model, + int n_threads); + + // + // KV cache + // + + // Information associated with an individual cell in the KV cache view. + struct llama_kv_cache_view_cell { + // The position for this cell. Takes KV cache shifts into account. + // May be negative if the cell is not populated. + llama_pos pos; + }; + + // An updateable view of the KV cache. + struct llama_kv_cache_view { + // Number of KV cache cells. This will be the same as the context size. + int32_t n_cells; + + // Maximum number of sequences that can exist in a cell. It's not an error + // if there are more sequences in a cell than this value, however they will + // not be visible in the view cells_sequences. + int32_t n_max_seq; + + // Number of tokens in the cache. For example, if there are two populated + // cells, the first with 1 sequence id in it and the second with 2 sequence + // ids then you'll have 3 tokens. + int32_t token_count; + + // Number of populated cache cells. + int32_t used_cells; + + // Maximum contiguous empty slots in the cache. + int32_t max_contiguous; + + // Index to the start of the max_contiguous slot range. Can be negative + // when cache is full. + int32_t max_contiguous_idx; + + // Information for an individual cell. + struct llama_kv_cache_view_cell * cells; + + // The sequences for each cell. There will be n_max_seq items per cell. + llama_seq_id * cells_sequences; + }; + + // Create an empty KV cache view. (use only for debugging purposes) + LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq); + + // Free a KV cache view. (use only for debugging purposes) + LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); + + // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) + LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); + + // Returns the number of tokens in the KV cache (slow, use only for debug) + // If a KV cell has multiple sequences assigned to it, it will be counted multiple times + LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); + + // Returns the number of used KV cells (i.e. 
have at least one sequence assigned to them) + LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx); + + // Clear the KV cache + LLAMA_API void llama_kv_cache_clear( + struct llama_context * ctx); + + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // seq_id < 0 : match any sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1); + + // Copy all tokens that belong to the specified sequence to another sequence + // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); + + // Removes all tokens that do not belong to the specified sequence + LLAMA_API void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id); + + // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) + // If the KV cache is RoPEd, the KV data is updated accordingly + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_kv_cache_seq_shift( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta); + + // + // State / sessions + // + + // Returns the maximum size in bytes of the state (rng, logits, embedding + // and kv_cache) - will often be smaller after compacting tokens + LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); + + // Copies the state to the specified destination address. + // Destination needs to have allocated enough memory. + // Returns the number of bytes copied + LLAMA_API size_t llama_copy_state_data( + struct llama_context * ctx, + uint8_t * dst); + + // Set the state reading from the specified address + // Returns the number of bytes read + LLAMA_API size_t llama_set_state_data( + struct llama_context * ctx, + uint8_t * src); + + // Save/load session file + LLAMA_API bool llama_load_session_file( + struct llama_context * ctx, + const char * path_session, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + + LLAMA_API bool llama_save_session_file( + struct llama_context * ctx, + const char * path_session, + const llama_token * tokens, + size_t n_token_count); + + // + // Decoding + // + + // Run the llama inference to obtain the logits and probabilities for the next token(s). + // tokens + n_tokens is the provided batch of new tokens to process + // n_past is the number of tokens to use from previous eval calls + // Returns 0 on success + // DEPRECATED: use llama_decode() instead + LLAMA_API DEPRECATED(int llama_eval( + struct llama_context * ctx, + llama_token * tokens, + int32_t n_tokens, + int n_past), + "use llama_decode() instead"); + + // Same as llama_eval, but use float matrix input directly. 
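
    // Illustrative sketch (not part of the upstream llama.cpp API): "context
    // shifting" built on the KV cache calls declared above. After the first
    // n_keep tokens, n_discard tokens are dropped and the rest of the sequence
    // is slid back so decoding can continue without re-evaluating the kept
    // prefix. The function name and parameters are illustrative.
    static void example_shift_kv_cache(struct llama_context * ctx, llama_seq_id seq_id,
                                       llama_pos n_keep, llama_pos n_past, llama_pos n_discard) {
        // remove tokens [n_keep, n_keep + n_discard) from the sequence
        llama_kv_cache_seq_rm   (ctx, seq_id, n_keep, n_keep + n_discard);
        // shift the remaining tokens [n_keep + n_discard, n_past) back by n_discard positions
        llama_kv_cache_seq_shift(ctx, seq_id, n_keep + n_discard, n_past, -n_discard);
    }
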
+ // DEPRECATED: use llama_decode() instead + LLAMA_API DEPRECATED(int llama_eval_embd( + struct llama_context * ctx, + float * embd, + int32_t n_tokens, + int n_past), + "use llama_decode() instead"); + + // Return batch for single sequence of tokens starting at pos_0 + // + // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it + // + LLAMA_API struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens, + llama_pos pos_0, + llama_seq_id seq_id); + + // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens + // Each token can be assigned up to n_seq_max sequence ids + // The batch has to be freed with llama_batch_free() + // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) + // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token + // The rest of the llama_batch members are allocated with size n_tokens + // All members are left uninitialized + LLAMA_API struct llama_batch llama_batch_init( + int32_t n_tokens, + int32_t embd, + int32_t n_seq_max); + + // Frees a batch of tokens allocated with llama_batch_init() + LLAMA_API void llama_batch_free(struct llama_batch batch); + + // Positive return values does not mean a fatal error, but rather a warning. + // 0 - success + // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) + // < 0 - error + LLAMA_API int llama_decode( + struct llama_context * ctx, + struct llama_batch batch); + + // Set the number of threads used for decoding + // n_threads is the number of threads used for generation (single token) + // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) + LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); + + // Token logits obtained from the last call to llama_eval() + // The logits for the last token are stored in the last row + // Logits for which llama_batch.logits[i] == 0 are undefined + // Rows: n_tokens provided with llama_batch + // Cols: n_vocab + LLAMA_API float * llama_get_logits(struct llama_context * ctx); + + // Logits for the ith token. Equivalent to: + // llama_get_logits(ctx) + i*n_vocab + LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + + // Get the embeddings for the input + // shape: [n_embd] (1-dimensional) + LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); + + // + // Vocab + // + + LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); + + LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); + + LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); + + // Special tokens + LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence + LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence + LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line + + // Returns -1 if unknown, 1 for true or 0 for false. + LLAMA_API int llama_add_bos_token(const struct llama_model * model); + + // Returns -1 if unknown, 1 for true or 0 for false. 
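
    // Illustrative sketch (not part of the upstream llama.cpp API): feed a chunk
    // of tokens to llama_decode() via the llama_batch_get_one() helper and pick
    // the most likely next token. Assumes the default context parameters, where
    // logits_all is false and llama_batch.logits is NULL, so llama_get_logits()
    // returns a single row holding the logits of the last token of the batch.
    // The function name is illustrative.
    static llama_token example_decode_greedy(struct llama_context * ctx, const struct llama_model * model,
                                             llama_token * tokens, int32_t n_tokens, llama_pos n_past) {
        struct llama_batch batch = llama_batch_get_one(tokens, n_tokens, n_past, 0);
        if (llama_decode(ctx, batch) != 0) {
            return -1; // decoding failed or no KV slot was available
        }

        const float * logits  = llama_get_logits(ctx);
        const int     n_vocab = llama_n_vocab(model);

        llama_token best = 0;
        for (llama_token id = 1; id < n_vocab; ++id) {
            if (logits[id] > logits[best]) {
                best = id;
            }
        }
        return best;
    }
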
+ LLAMA_API int llama_add_eos_token(const struct llama_model * model); + + // codellama infill tokens + LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix + LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle + LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix + LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle + + // + // Tokenization + // + + /// @details Convert the provided text into tokens. + /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. + /// @return Returns the number of tokens on success, no more than n_max_tokens + /// @return Returns a negative number on failure - the number of tokens that would have been returned + /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. + /// Does not insert a leading space. + LLAMA_API int llama_tokenize( + const struct llama_model * model, + const char * text, + int text_len, + llama_token * tokens, + int n_max_tokens, + bool add_bos, + bool special); + + // Token Id -> Piece. + // Uses the vocabulary in the provided context. + // Does not write null terminator to the buffer. + // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. + LLAMA_API int llama_token_to_piece( + const struct llama_model * model, + llama_token token, + char * buf, + int length); + + // + // Grammar + // + + LLAMA_API struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + + LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + + LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); + + // + // Sampling functions + // + + // Sets the current rng seed. + LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + LLAMA_API void llama_sample_repetition_penalties( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present); + + /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. + /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. + /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. 
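
    // Illustrative sketch (not part of the upstream llama.cpp API): tokenize a
    // prompt and print each token id with its text piece. The buffer sizes and
    // the function name are illustrative; per the comments above, a negative
    // result from llama_tokenize() reports the required capacity and
    // llama_token_to_piece() does not write a terminating '\0'.
    static void example_print_tokens(const struct llama_model * model, const char * text, int text_len) {
        llama_token tokens[512];
        const int n_tokens = llama_tokenize(model, text, text_len, tokens, 512,
                                            /*add_bos=*/true, /*special=*/false);
        if (n_tokens < 0) {
            return; // more than 512 tokens would be needed
        }
        for (int i = 0; i < n_tokens; ++i) {
            char piece[64];
            const int n = llama_token_to_piece(model, tokens[i], piece, (int) sizeof(piece) - 1);
            if (n < 0) {
                continue; // piece longer than the buffer
            }
            piece[n] = '\0';
            printf("%6d -> '%s'\n", tokens[i], piece);
        }
    }
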
+ LLAMA_API void llama_sample_classifier_free_guidance( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_context * guidance_ctx, + float scale); + + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. + LLAMA_API void llama_sample_softmax( + struct llama_context * ctx, + llama_token_data_array * candidates); + + /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API void llama_sample_top_k( + struct llama_context * ctx, + llama_token_data_array * candidates, + int k, + size_t min_keep); + + /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API void llama_sample_top_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + + /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + LLAMA_API void llama_sample_min_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. + LLAMA_API void llama_sample_tail_free( + struct llama_context * ctx, + llama_token_data_array * candidates, + float z, + size_t min_keep); + + /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + LLAMA_API void llama_sample_typical( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + + LLAMA_API void llama_sample_temp( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp); + + LLAMA_API DEPRECATED(void llama_sample_temperature( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp), + "use llama_sample_temp instead"); + + /// @details Apply constraints from grammar + LLAMA_API void llama_sample_grammar( + struct llama_context * ctx, + llama_token_data_array * candidates, + const struct llama_grammar * grammar); + + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
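
    // Illustrative sketch (not part of the upstream llama.cpp API): build a
    // llama_token_data_array from the current logits and run a top-k / top-p /
    // temperature chain before taking the most probable survivor. The cutoff
    // values (k = 40, p = 0.95, temp = 0.8) and the function name are
    // illustrative; candidates_buf must hold llama_n_vocab(model) entries.
    static llama_token example_sample_next(struct llama_context * ctx, const struct llama_model * model,
                                           llama_token_data * candidates_buf) {
        const float * logits  = llama_get_logits(ctx);
        const int     n_vocab = llama_n_vocab(model);

        for (llama_token id = 0; id < n_vocab; ++id) {
            candidates_buf[id].id    = id;
            candidates_buf[id].logit = logits[id];
            candidates_buf[id].p     = 0.0f;
        }
        llama_token_data_array candidates = { candidates_buf, (size_t) n_vocab, false };

        llama_sample_top_k(ctx, &candidates, 40, 1);
        llama_sample_top_p(ctx, &candidates, 0.95f, 1);
        llama_sample_temp (ctx, &candidates, 0.80f);

        // llama_sample_softmax() sorts the candidates by probability, so the
        // first entry is the most probable token; for stochastic sampling use
        // llama_sample_token(), declared further below.
        llama_sample_softmax(ctx, &candidates);
        return candidates.data[0].id;
    }
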
+ LLAMA_API llama_token llama_sample_token_mirostat( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + int m, + float * mu); + + /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API llama_token llama_sample_token_mirostat_v2( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + float * mu); + + /// @details Selects the token with the highest probability. + /// Does not compute the token probabilities. Use llama_sample_softmax() instead. + LLAMA_API llama_token llama_sample_token_greedy( + struct llama_context * ctx, + llama_token_data_array * candidates); + + /// @details Randomly selects a token from the candidates based on their probabilities. + LLAMA_API llama_token llama_sample_token( + struct llama_context * ctx, + llama_token_data_array * candidates); + + /// @details Accepts the sampled token into the grammar + LLAMA_API void llama_grammar_accept_token( + struct llama_context * ctx, + struct llama_grammar * grammar, + llama_token token); + + // + // Beam search + // + + struct llama_beam_view { + const llama_token * tokens; + + size_t n_tokens; + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Callback should set this to true when a beam is at end-of-beam. + }; + + // Passed to beam_search_callback function. + // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams + // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. + // These pointers are valid only during the synchronous callback, so should not be saved. + struct llama_beams_state { + struct llama_beam_view * beam_views; + + size_t n_beams; // Number of elements in beam_views[]. + size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. + bool last_call; // True iff this is the last callback invocation. + }; + + // Type of pointer to the beam_search_callback function. + // void* callback_data is any custom data passed to llama_beam_search, that is subsequently + // passed back to beam_search_callback. This avoids having to use global variables in the callback. + typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state); + + /// @details Deterministically returns entire sentence constructed by a beam search. + /// @param ctx Pointer to the llama_context. + /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state. 
+ /// @param callback_data A pointer that is simply passed back to callback. + /// @param n_beams Number of beams to use. + /// @param n_past Number of tokens already evaluated. + /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. + LLAMA_API void llama_beam_search( + struct llama_context * ctx, + llama_beam_search_callback_fn_t callback, + void * callback_data, + size_t n_beams, + int n_past, + int n_predict); + + // Performance information + LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); + + LLAMA_API void llama_print_timings(struct llama_context * ctx); + LLAMA_API void llama_reset_timings(struct llama_context * ctx); + + // Print system information + LLAMA_API const char * llama_print_system_info(void); + + // Set callback for all future logging events. + // If this is not called, or NULL is supplied, everything is output on stderr. + LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); + + LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); + +#ifdef __cplusplus +} +#endif + +// Internal API to be implemented by llama.cpp and used by tests/benchmarks only +#ifdef LLAMA_API_INTERNAL + +#include +#include + +struct ggml_tensor; + +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +); + +#endif // LLAMA_API_INTERNAL + +#endif // LLAMA_H diff --git a/applications/src/llama/unicode.h b/applications/src/llama/unicode.h new file mode 100644 index 000000000..aeca879ea --- /dev/null +++ b/applications/src/llama/unicode.h @@ -0,0 +1,462 @@ +#pragma once + +#include +#include +#include +#include + +static const std::vector> digit_ranges = { +{0x30, 0x39}, {0xB2, 0xB3}, {0xB9, 0xB9}, {0x660, 0x669}, {0x6F0, 0x6F9}, {0x7C0, 0x7C9}, {0x966, 0x96F}, {0x9E6, 0x9EF}, {0xA66, 0xA6F}, {0xAE6, 0xAEF}, {0xB66, 0xB6F}, {0xBE6, 0xBEF}, {0xC66, 0xC6F}, +{0xCE6, 0xCEF}, {0xD66, 0xD6F}, {0xDE6, 0xDEF}, {0xE50, 0xE59}, {0xED0, 0xED9}, {0xF20, 0xF29}, {0x1040, 0x1049}, {0x1090, 0x1099}, {0x1369, 0x1371}, {0x17E0, 0x17E9}, {0x1810, 0x1819}, {0x1946, 0x194F}, +{0x19D0, 0x19DA}, {0x1A80, 0x1A89}, {0x1A90, 0x1A99}, {0x1B50, 0x1B59}, {0x1BB0, 0x1BB9}, {0x1C40, 0x1C49}, {0x1C50, 0x1C59}, {0x2070, 0x2070}, {0x2074, 0x2079}, {0x2080, 0x2089}, {0x2460, 0x2468}, +{0x2474, 0x247C}, {0x2488, 0x2490}, {0x24EA, 0x24EA}, {0x24F5, 0x24FD}, {0x24FF, 0x24FF}, {0x2776, 0x277E}, {0x2780, 0x2788}, {0x278A, 0x2792}, {0xA620, 0xA629}, {0xA8D0, 0xA8D9}, {0xA900, 0xA909}, +{0xA9D0, 0xA9D9}, {0xA9F0, 0xA9F9}, {0xAA50, 0xAA59}, {0xABF0, 0xABF9}, {0xFF10, 0xFF19}, {0x104A0, 0x104A9}, {0x10A40, 0x10A43}, {0x10D30, 0x10D39}, {0x10E60, 0x10E68}, {0x11052, 0x1105A}, +{0x11066, 0x1106F}, {0x110F0, 0x110F9}, {0x11136, 0x1113F}, {0x111D0, 0x111D9}, {0x112F0, 0x112F9}, {0x11450, 0x11459}, {0x114D0, 0x114D9}, {0x11650, 0x11659}, {0x116C0, 0x116C9}, {0x11730, 0x11739}, +{0x118E0, 0x118E9}, {0x11950, 0x11959}, {0x11C50, 0x11C59}, {0x11D50, 0x11D59}, {0x11DA0, 0x11DA9}, {0x16A60, 0x16A69}, {0x16B50, 0x16B59}, {0x1D7CE, 0x1D7FF}, {0x1E140, 0x1E149}, {0x1E2F0, 0x1E2F9}, +{0x1E950, 0x1E959}, {0x1F100, 0x1F10A}, {0x1FBF0, 0x1FBF9}, +}; + +static const std::vector> letter_ranges = { +{0x41, 0x5A}, {0x61, 0x7A}, {0xAA, 0xAA}, {0xB5, 0xB5}, {0xBA, 0xBA}, {0xC0, 0xD6}, {0xD8, 0xF6}, {0xF8, 0x2C1}, {0x2C6, 0x2D1}, {0x2E0, 0x2E4}, {0x2EC, 0x2EC}, {0x2EE, 0x2EE}, {0x370, 0x374}, +{0x376, 0x377}, {0x37A, 0x37D}, {0x37F, 0x37F}, {0x386, 0x386}, {0x388, 0x38A}, {0x38C, 0x38C}, {0x38E, 0x3A1}, {0x3A3, 0x3F5}, 
{0x3F7, 0x481}, {0x48A, 0x52F}, {0x531, 0x556}, {0x559, 0x559}, +{0x560, 0x588}, {0x5D0, 0x5EA}, {0x5EF, 0x5F2}, {0x620, 0x64A}, {0x66E, 0x66F}, {0x671, 0x6D3}, {0x6D5, 0x6D5}, {0x6E5, 0x6E6}, {0x6EE, 0x6EF}, {0x6FA, 0x6FC}, {0x6FF, 0x6FF}, {0x710, 0x710}, +{0x712, 0x72F}, {0x74D, 0x7A5}, {0x7B1, 0x7B1}, {0x7CA, 0x7EA}, {0x7F4, 0x7F5}, {0x7FA, 0x7FA}, {0x800, 0x815}, {0x81A, 0x81A}, {0x824, 0x824}, {0x828, 0x828}, {0x840, 0x858}, {0x860, 0x86A}, +{0x8A0, 0x8B4}, {0x8B6, 0x8C7}, {0x904, 0x939}, {0x93D, 0x93D}, {0x950, 0x950}, {0x958, 0x961}, {0x971, 0x980}, {0x985, 0x98C}, {0x98F, 0x990}, {0x993, 0x9A8}, {0x9AA, 0x9B0}, {0x9B2, 0x9B2}, +{0x9B6, 0x9B9}, {0x9BD, 0x9BD}, {0x9CE, 0x9CE}, {0x9DC, 0x9DD}, {0x9DF, 0x9E1}, {0x9F0, 0x9F1}, {0x9FC, 0x9FC}, {0xA05, 0xA0A}, {0xA0F, 0xA10}, {0xA13, 0xA28}, {0xA2A, 0xA30}, {0xA32, 0xA33}, +{0xA35, 0xA36}, {0xA38, 0xA39}, {0xA59, 0xA5C}, {0xA5E, 0xA5E}, {0xA72, 0xA74}, {0xA85, 0xA8D}, {0xA8F, 0xA91}, {0xA93, 0xAA8}, {0xAAA, 0xAB0}, {0xAB2, 0xAB3}, {0xAB5, 0xAB9}, {0xABD, 0xABD}, +{0xAD0, 0xAD0}, {0xAE0, 0xAE1}, {0xAF9, 0xAF9}, {0xB05, 0xB0C}, {0xB0F, 0xB10}, {0xB13, 0xB28}, {0xB2A, 0xB30}, {0xB32, 0xB33}, {0xB35, 0xB39}, {0xB3D, 0xB3D}, {0xB5C, 0xB5D}, {0xB5F, 0xB61}, +{0xB71, 0xB71}, {0xB83, 0xB83}, {0xB85, 0xB8A}, {0xB8E, 0xB90}, {0xB92, 0xB95}, {0xB99, 0xB9A}, {0xB9C, 0xB9C}, {0xB9E, 0xB9F}, {0xBA3, 0xBA4}, {0xBA8, 0xBAA}, {0xBAE, 0xBB9}, {0xBD0, 0xBD0}, +{0xC05, 0xC0C}, {0xC0E, 0xC10}, {0xC12, 0xC28}, {0xC2A, 0xC39}, {0xC3D, 0xC3D}, {0xC58, 0xC5A}, {0xC60, 0xC61}, {0xC80, 0xC80}, {0xC85, 0xC8C}, {0xC8E, 0xC90}, {0xC92, 0xCA8}, {0xCAA, 0xCB3}, +{0xCB5, 0xCB9}, {0xCBD, 0xCBD}, {0xCDE, 0xCDE}, {0xCE0, 0xCE1}, {0xCF1, 0xCF2}, {0xD04, 0xD0C}, {0xD0E, 0xD10}, {0xD12, 0xD3A}, {0xD3D, 0xD3D}, {0xD4E, 0xD4E}, {0xD54, 0xD56}, {0xD5F, 0xD61}, +{0xD7A, 0xD7F}, {0xD85, 0xD96}, {0xD9A, 0xDB1}, {0xDB3, 0xDBB}, {0xDBD, 0xDBD}, {0xDC0, 0xDC6}, {0xE01, 0xE30}, {0xE32, 0xE33}, {0xE40, 0xE46}, {0xE81, 0xE82}, {0xE84, 0xE84}, {0xE86, 0xE8A}, +{0xE8C, 0xEA3}, {0xEA5, 0xEA5}, {0xEA7, 0xEB0}, {0xEB2, 0xEB3}, {0xEBD, 0xEBD}, {0xEC0, 0xEC4}, {0xEC6, 0xEC6}, {0xEDC, 0xEDF}, {0xF00, 0xF00}, {0xF40, 0xF47}, {0xF49, 0xF6C}, {0xF88, 0xF8C}, +{0x1000, 0x102A}, {0x103F, 0x103F}, {0x1050, 0x1055}, {0x105A, 0x105D}, {0x1061, 0x1061}, {0x1065, 0x1066}, {0x106E, 0x1070}, {0x1075, 0x1081}, {0x108E, 0x108E}, {0x10A0, 0x10C5}, {0x10C7, 0x10C7}, +{0x10CD, 0x10CD}, {0x10D0, 0x10FA}, {0x10FC, 0x1248}, {0x124A, 0x124D}, {0x1250, 0x1256}, {0x1258, 0x1258}, {0x125A, 0x125D}, {0x1260, 0x1288}, {0x128A, 0x128D}, {0x1290, 0x12B0}, {0x12B2, 0x12B5}, +{0x12B8, 0x12BE}, {0x12C0, 0x12C0}, {0x12C2, 0x12C5}, {0x12C8, 0x12D6}, {0x12D8, 0x1310}, {0x1312, 0x1315}, {0x1318, 0x135A}, {0x1380, 0x138F}, {0x13A0, 0x13F5}, {0x13F8, 0x13FD}, {0x1401, 0x166C}, +{0x166F, 0x167F}, {0x1681, 0x169A}, {0x16A0, 0x16EA}, {0x16F1, 0x16F8}, {0x1700, 0x170C}, {0x170E, 0x1711}, {0x1720, 0x1731}, {0x1740, 0x1751}, {0x1760, 0x176C}, {0x176E, 0x1770}, {0x1780, 0x17B3}, +{0x17D7, 0x17D7}, {0x17DC, 0x17DC}, {0x1820, 0x1878}, {0x1880, 0x1884}, {0x1887, 0x18A8}, {0x18AA, 0x18AA}, {0x18B0, 0x18F5}, {0x1900, 0x191E}, {0x1950, 0x196D}, {0x1970, 0x1974}, {0x1980, 0x19AB}, +{0x19B0, 0x19C9}, {0x1A00, 0x1A16}, {0x1A20, 0x1A54}, {0x1AA7, 0x1AA7}, {0x1B05, 0x1B33}, {0x1B45, 0x1B4B}, {0x1B83, 0x1BA0}, {0x1BAE, 0x1BAF}, {0x1BBA, 0x1BE5}, {0x1C00, 0x1C23}, {0x1C4D, 0x1C4F}, +{0x1C5A, 0x1C7D}, {0x1C80, 0x1C88}, {0x1C90, 0x1CBA}, {0x1CBD, 0x1CBF}, {0x1CE9, 0x1CEC}, {0x1CEE, 0x1CF3}, {0x1CF5, 0x1CF6}, {0x1CFA, 0x1CFA}, {0x1D00, 0x1DBF}, {0x1E00, 
0x1F15}, {0x1F18, 0x1F1D}, +{0x1F20, 0x1F45}, {0x1F48, 0x1F4D}, {0x1F50, 0x1F57}, {0x1F59, 0x1F59}, {0x1F5B, 0x1F5B}, {0x1F5D, 0x1F5D}, {0x1F5F, 0x1F7D}, {0x1F80, 0x1FB4}, {0x1FB6, 0x1FBC}, {0x1FBE, 0x1FBE}, {0x1FC2, 0x1FC4}, +{0x1FC6, 0x1FCC}, {0x1FD0, 0x1FD3}, {0x1FD6, 0x1FDB}, {0x1FE0, 0x1FEC}, {0x1FF2, 0x1FF4}, {0x1FF6, 0x1FFC}, {0x2071, 0x2071}, {0x207F, 0x207F}, {0x2090, 0x209C}, {0x2102, 0x2102}, {0x2107, 0x2107}, +{0x210A, 0x2113}, {0x2115, 0x2115}, {0x2119, 0x211D}, {0x2124, 0x2124}, {0x2126, 0x2126}, {0x2128, 0x2128}, {0x212A, 0x212D}, {0x212F, 0x2139}, {0x213C, 0x213F}, {0x2145, 0x2149}, {0x214E, 0x214E}, +{0x2183, 0x2184}, {0x2C00, 0x2C2E}, {0x2C30, 0x2C5E}, {0x2C60, 0x2CE4}, {0x2CEB, 0x2CEE}, {0x2CF2, 0x2CF3}, {0x2D00, 0x2D25}, {0x2D27, 0x2D27}, {0x2D2D, 0x2D2D}, {0x2D30, 0x2D67}, {0x2D6F, 0x2D6F}, +{0x2D80, 0x2D96}, {0x2DA0, 0x2DA6}, {0x2DA8, 0x2DAE}, {0x2DB0, 0x2DB6}, {0x2DB8, 0x2DBE}, {0x2DC0, 0x2DC6}, {0x2DC8, 0x2DCE}, {0x2DD0, 0x2DD6}, {0x2DD8, 0x2DDE}, {0x2E2F, 0x2E2F}, {0x3005, 0x3006}, +{0x3031, 0x3035}, {0x303B, 0x303C}, {0x3041, 0x3096}, {0x309D, 0x309F}, {0x30A1, 0x30FA}, {0x30FC, 0x30FF}, {0x3105, 0x312F}, {0x3131, 0x318E}, {0x31A0, 0x31BF}, {0x31F0, 0x31FF}, {0x3400, 0x4DBF}, +{0x4E00, 0x9FFC}, {0xA000, 0xA48C}, {0xA4D0, 0xA4FD}, {0xA500, 0xA60C}, {0xA610, 0xA61F}, {0xA62A, 0xA62B}, {0xA640, 0xA66E}, {0xA67F, 0xA69D}, {0xA6A0, 0xA6E5}, {0xA717, 0xA71F}, {0xA722, 0xA788}, +{0xA78B, 0xA7BF}, {0xA7C2, 0xA7CA}, {0xA7F5, 0xA801}, {0xA803, 0xA805}, {0xA807, 0xA80A}, {0xA80C, 0xA822}, {0xA840, 0xA873}, {0xA882, 0xA8B3}, {0xA8F2, 0xA8F7}, {0xA8FB, 0xA8FB}, {0xA8FD, 0xA8FE}, +{0xA90A, 0xA925}, {0xA930, 0xA946}, {0xA960, 0xA97C}, {0xA984, 0xA9B2}, {0xA9CF, 0xA9CF}, {0xA9E0, 0xA9E4}, {0xA9E6, 0xA9EF}, {0xA9FA, 0xA9FE}, {0xAA00, 0xAA28}, {0xAA40, 0xAA42}, {0xAA44, 0xAA4B}, +{0xAA60, 0xAA76}, {0xAA7A, 0xAA7A}, {0xAA7E, 0xAAAF}, {0xAAB1, 0xAAB1}, {0xAAB5, 0xAAB6}, {0xAAB9, 0xAABD}, {0xAAC0, 0xAAC0}, {0xAAC2, 0xAAC2}, {0xAADB, 0xAADD}, {0xAAE0, 0xAAEA}, {0xAAF2, 0xAAF4}, +{0xAB01, 0xAB06}, {0xAB09, 0xAB0E}, {0xAB11, 0xAB16}, {0xAB20, 0xAB26}, {0xAB28, 0xAB2E}, {0xAB30, 0xAB5A}, {0xAB5C, 0xAB69}, {0xAB70, 0xABE2}, {0xAC00, 0xD7A3}, {0xD7B0, 0xD7C6}, {0xD7CB, 0xD7FB}, +{0xF900, 0xFA6D}, {0xFA70, 0xFAD9}, {0xFB00, 0xFB06}, {0xFB13, 0xFB17}, {0xFB1D, 0xFB1D}, {0xFB1F, 0xFB28}, {0xFB2A, 0xFB36}, {0xFB38, 0xFB3C}, {0xFB3E, 0xFB3E}, {0xFB40, 0xFB41}, {0xFB43, 0xFB44}, +{0xFB46, 0xFBB1}, {0xFBD3, 0xFD3D}, {0xFD50, 0xFD8F}, {0xFD92, 0xFDC7}, {0xFDF0, 0xFDFB}, {0xFE70, 0xFE74}, {0xFE76, 0xFEFC}, {0xFF21, 0xFF3A}, {0xFF41, 0xFF5A}, {0xFF66, 0xFFBE}, {0xFFC2, 0xFFC7}, +{0xFFCA, 0xFFCF}, {0xFFD2, 0xFFD7}, {0xFFDA, 0xFFDC}, {0x10000, 0x1000B}, {0x1000D, 0x10026}, {0x10028, 0x1003A}, {0x1003C, 0x1003D}, {0x1003F, 0x1004D}, {0x10050, 0x1005D}, {0x10080, 0x100FA}, +{0x10280, 0x1029C}, {0x102A0, 0x102D0}, {0x10300, 0x1031F}, {0x1032D, 0x10340}, {0x10342, 0x10349}, {0x10350, 0x10375}, {0x10380, 0x1039D}, {0x103A0, 0x103C3}, {0x103C8, 0x103CF}, {0x10400, 0x1049D}, +{0x104B0, 0x104D3}, {0x104D8, 0x104FB}, {0x10500, 0x10527}, {0x10530, 0x10563}, {0x10600, 0x10736}, {0x10740, 0x10755}, {0x10760, 0x10767}, {0x10800, 0x10805}, {0x10808, 0x10808}, {0x1080A, 0x10835}, +{0x10837, 0x10838}, {0x1083C, 0x1083C}, {0x1083F, 0x10855}, {0x10860, 0x10876}, {0x10880, 0x1089E}, {0x108E0, 0x108F2}, {0x108F4, 0x108F5}, {0x10900, 0x10915}, {0x10920, 0x10939}, {0x10980, 0x109B7}, +{0x109BE, 0x109BF}, {0x10A00, 0x10A00}, {0x10A10, 0x10A13}, {0x10A15, 0x10A17}, {0x10A19, 0x10A35}, {0x10A60, 0x10A7C}, {0x10A80, 0x10A9C}, 
{0x10AC0, 0x10AC7}, {0x10AC9, 0x10AE4}, {0x10B00, 0x10B35}, +{0x10B40, 0x10B55}, {0x10B60, 0x10B72}, {0x10B80, 0x10B91}, {0x10C00, 0x10C48}, {0x10C80, 0x10CB2}, {0x10CC0, 0x10CF2}, {0x10D00, 0x10D23}, {0x10E80, 0x10EA9}, {0x10EB0, 0x10EB1}, {0x10F00, 0x10F1C}, +{0x10F27, 0x10F27}, {0x10F30, 0x10F45}, {0x10FB0, 0x10FC4}, {0x10FE0, 0x10FF6}, {0x11003, 0x11037}, {0x11083, 0x110AF}, {0x110D0, 0x110E8}, {0x11103, 0x11126}, {0x11144, 0x11144}, {0x11147, 0x11147}, +{0x11150, 0x11172}, {0x11176, 0x11176}, {0x11183, 0x111B2}, {0x111C1, 0x111C4}, {0x111DA, 0x111DA}, {0x111DC, 0x111DC}, {0x11200, 0x11211}, {0x11213, 0x1122B}, {0x11280, 0x11286}, {0x11288, 0x11288}, +{0x1128A, 0x1128D}, {0x1128F, 0x1129D}, {0x1129F, 0x112A8}, {0x112B0, 0x112DE}, {0x11305, 0x1130C}, {0x1130F, 0x11310}, {0x11313, 0x11328}, {0x1132A, 0x11330}, {0x11332, 0x11333}, {0x11335, 0x11339}, +{0x1133D, 0x1133D}, {0x11350, 0x11350}, {0x1135D, 0x11361}, {0x11400, 0x11434}, {0x11447, 0x1144A}, {0x1145F, 0x11461}, {0x11480, 0x114AF}, {0x114C4, 0x114C5}, {0x114C7, 0x114C7}, {0x11580, 0x115AE}, +{0x115D8, 0x115DB}, {0x11600, 0x1162F}, {0x11644, 0x11644}, {0x11680, 0x116AA}, {0x116B8, 0x116B8}, {0x11700, 0x1171A}, {0x11800, 0x1182B}, {0x118A0, 0x118DF}, {0x118FF, 0x11906}, {0x11909, 0x11909}, +{0x1190C, 0x11913}, {0x11915, 0x11916}, {0x11918, 0x1192F}, {0x1193F, 0x1193F}, {0x11941, 0x11941}, {0x119A0, 0x119A7}, {0x119AA, 0x119D0}, {0x119E1, 0x119E1}, {0x119E3, 0x119E3}, {0x11A00, 0x11A00}, +{0x11A0B, 0x11A32}, {0x11A3A, 0x11A3A}, {0x11A50, 0x11A50}, {0x11A5C, 0x11A89}, {0x11A9D, 0x11A9D}, {0x11AC0, 0x11AF8}, {0x11C00, 0x11C08}, {0x11C0A, 0x11C2E}, {0x11C40, 0x11C40}, {0x11C72, 0x11C8F}, +{0x11D00, 0x11D06}, {0x11D08, 0x11D09}, {0x11D0B, 0x11D30}, {0x11D46, 0x11D46}, {0x11D60, 0x11D65}, {0x11D67, 0x11D68}, {0x11D6A, 0x11D89}, {0x11D98, 0x11D98}, {0x11EE0, 0x11EF2}, {0x11FB0, 0x11FB0}, +{0x12000, 0x12399}, {0x12480, 0x12543}, {0x13000, 0x1342E}, {0x14400, 0x14646}, {0x16800, 0x16A38}, {0x16A40, 0x16A5E}, {0x16AD0, 0x16AED}, {0x16B00, 0x16B2F}, {0x16B40, 0x16B43}, {0x16B63, 0x16B77}, +{0x16B7D, 0x16B8F}, {0x16E40, 0x16E7F}, {0x16F00, 0x16F4A}, {0x16F50, 0x16F50}, {0x16F93, 0x16F9F}, {0x16FE0, 0x16FE1}, {0x16FE3, 0x16FE3}, {0x17000, 0x187F7}, {0x18800, 0x18CD5}, {0x18D00, 0x18D08}, +{0x1B000, 0x1B11E}, {0x1B150, 0x1B152}, {0x1B164, 0x1B167}, {0x1B170, 0x1B2FB}, {0x1BC00, 0x1BC6A}, {0x1BC70, 0x1BC7C}, {0x1BC80, 0x1BC88}, {0x1BC90, 0x1BC99}, {0x1D400, 0x1D454}, {0x1D456, 0x1D49C}, +{0x1D49E, 0x1D49F}, {0x1D4A2, 0x1D4A2}, {0x1D4A5, 0x1D4A6}, {0x1D4A9, 0x1D4AC}, {0x1D4AE, 0x1D4B9}, {0x1D4BB, 0x1D4BB}, {0x1D4BD, 0x1D4C3}, {0x1D4C5, 0x1D505}, {0x1D507, 0x1D50A}, {0x1D50D, 0x1D514}, +{0x1D516, 0x1D51C}, {0x1D51E, 0x1D539}, {0x1D53B, 0x1D53E}, {0x1D540, 0x1D544}, {0x1D546, 0x1D546}, {0x1D54A, 0x1D550}, {0x1D552, 0x1D6A5}, {0x1D6A8, 0x1D6C0}, {0x1D6C2, 0x1D6DA}, {0x1D6DC, 0x1D6FA}, +{0x1D6FC, 0x1D714}, {0x1D716, 0x1D734}, {0x1D736, 0x1D74E}, {0x1D750, 0x1D76E}, {0x1D770, 0x1D788}, {0x1D78A, 0x1D7A8}, {0x1D7AA, 0x1D7C2}, {0x1D7C4, 0x1D7CB}, {0x1E100, 0x1E12C}, {0x1E137, 0x1E13D}, +{0x1E14E, 0x1E14E}, {0x1E2C0, 0x1E2EB}, {0x1E800, 0x1E8C4}, {0x1E900, 0x1E943}, {0x1E94B, 0x1E94B}, {0x1EE00, 0x1EE03}, {0x1EE05, 0x1EE1F}, {0x1EE21, 0x1EE22}, {0x1EE24, 0x1EE24}, {0x1EE27, 0x1EE27}, +{0x1EE29, 0x1EE32}, {0x1EE34, 0x1EE37}, {0x1EE39, 0x1EE39}, {0x1EE3B, 0x1EE3B}, {0x1EE42, 0x1EE42}, {0x1EE47, 0x1EE47}, {0x1EE49, 0x1EE49}, {0x1EE4B, 0x1EE4B}, {0x1EE4D, 0x1EE4F}, {0x1EE51, 0x1EE52}, +{0x1EE54, 0x1EE54}, {0x1EE57, 0x1EE57}, {0x1EE59, 0x1EE59}, {0x1EE5B, 
0x1EE5B}, {0x1EE5D, 0x1EE5D}, {0x1EE5F, 0x1EE5F}, {0x1EE61, 0x1EE62}, {0x1EE64, 0x1EE64}, {0x1EE67, 0x1EE6A}, {0x1EE6C, 0x1EE72}, +{0x1EE74, 0x1EE77}, {0x1EE79, 0x1EE7C}, {0x1EE7E, 0x1EE7E}, {0x1EE80, 0x1EE89}, {0x1EE8B, 0x1EE9B}, {0x1EEA1, 0x1EEA3}, {0x1EEA5, 0x1EEA9}, {0x1EEAB, 0x1EEBB}, {0x20000, 0x2A6DD}, {0x2A700, 0x2B734}, +{0x2B740, 0x2B81D}, {0x2B820, 0x2CEA1}, {0x2CEB0, 0x2EBE0}, {0x2F800, 0x2FA1D}, {0x30000, 0x3134A}, +}; + +static const std::vector> whitespace_ranges = { +{0x9, 0xD}, {0x1C, 0x20}, {0x85, 0x85}, {0xA0, 0xA0}, {0x1680, 0x1680}, {0x2000, 0x200A}, {0x2028, 0x2029}, {0x202F, 0x202F}, {0x205F, 0x205F}, {0x3000, 0x3000}, +}; + +static const std::vector> accent_mark_ranges = { +{0x300, 0x36F}, {0x483, 0x489}, {0x591, 0x5BD}, {0x5BF, 0x5BF}, {0x5C1, 0x5C2}, {0x5C4, 0x5C5}, {0x5C7, 0x5C7}, {0x610, 0x61A}, {0x64B, 0x65F}, {0x670, 0x670}, {0x6D6, 0x6DC}, {0x6DF, 0x6E4}, +{0x6E7, 0x6E8}, {0x6EA, 0x6ED}, {0x711, 0x711}, {0x730, 0x74A}, {0x7A6, 0x7B0}, {0x7EB, 0x7F3}, {0x7FD, 0x7FD}, {0x816, 0x819}, {0x81B, 0x823}, {0x825, 0x827}, {0x829, 0x82D}, {0x859, 0x85B}, +{0x8D3, 0x8E1}, {0x8E3, 0x903}, {0x93A, 0x93C}, {0x93E, 0x94F}, {0x951, 0x957}, {0x962, 0x963}, {0x981, 0x983}, {0x9BC, 0x9BC}, {0x9BE, 0x9C4}, {0x9C7, 0x9C8}, {0x9CB, 0x9CD}, {0x9D7, 0x9D7}, +{0x9E2, 0x9E3}, {0x9FE, 0x9FE}, {0xA01, 0xA03}, {0xA3C, 0xA3C}, {0xA3E, 0xA42}, {0xA47, 0xA48}, {0xA4B, 0xA4D}, {0xA51, 0xA51}, {0xA70, 0xA71}, {0xA75, 0xA75}, {0xA81, 0xA83}, {0xABC, 0xABC}, +{0xABE, 0xAC5}, {0xAC7, 0xAC9}, {0xACB, 0xACD}, {0xAE2, 0xAE3}, {0xAFA, 0xAFF}, {0xB01, 0xB03}, {0xB3C, 0xB3C}, {0xB3E, 0xB44}, {0xB47, 0xB48}, {0xB4B, 0xB4D}, {0xB55, 0xB57}, {0xB62, 0xB63}, +{0xB82, 0xB82}, {0xBBE, 0xBC2}, {0xBC6, 0xBC8}, {0xBCA, 0xBCD}, {0xBD7, 0xBD7}, {0xC00, 0xC04}, {0xC3E, 0xC44}, {0xC46, 0xC48}, {0xC4A, 0xC4D}, {0xC55, 0xC56}, {0xC62, 0xC63}, {0xC81, 0xC83}, +{0xCBC, 0xCBC}, {0xCBE, 0xCC4}, {0xCC6, 0xCC8}, {0xCCA, 0xCCD}, {0xCD5, 0xCD6}, {0xCE2, 0xCE3}, {0xD00, 0xD03}, {0xD3B, 0xD3C}, {0xD3E, 0xD44}, {0xD46, 0xD48}, {0xD4A, 0xD4D}, {0xD57, 0xD57}, +{0xD62, 0xD63}, {0xD81, 0xD83}, {0xDCA, 0xDCA}, {0xDCF, 0xDD4}, {0xDD6, 0xDD6}, {0xDD8, 0xDDF}, {0xDF2, 0xDF3}, {0xE31, 0xE31}, {0xE34, 0xE3A}, {0xE47, 0xE4E}, {0xEB1, 0xEB1}, {0xEB4, 0xEBC}, +{0xEC8, 0xECD}, {0xF18, 0xF19}, {0xF35, 0xF35}, {0xF37, 0xF37}, {0xF39, 0xF39}, {0xF3E, 0xF3F}, {0xF71, 0xF84}, {0xF86, 0xF87}, {0xF8D, 0xF97}, {0xF99, 0xFBC}, {0xFC6, 0xFC6}, {0x102B, 0x103E}, +{0x1056, 0x1059}, {0x105E, 0x1060}, {0x1062, 0x1064}, {0x1067, 0x106D}, {0x1071, 0x1074}, {0x1082, 0x108D}, {0x108F, 0x108F}, {0x109A, 0x109D}, {0x135D, 0x135F}, {0x1712, 0x1714}, {0x1732, 0x1734}, +{0x1752, 0x1753}, {0x1772, 0x1773}, {0x17B4, 0x17D3}, {0x17DD, 0x17DD}, {0x180B, 0x180D}, {0x1885, 0x1886}, {0x18A9, 0x18A9}, {0x1920, 0x192B}, {0x1930, 0x193B}, {0x1A17, 0x1A1B}, {0x1A55, 0x1A5E}, +{0x1A60, 0x1A7C}, {0x1A7F, 0x1A7F}, {0x1AB0, 0x1AC0}, {0x1B00, 0x1B04}, {0x1B34, 0x1B44}, {0x1B6B, 0x1B73}, {0x1B80, 0x1B82}, {0x1BA1, 0x1BAD}, {0x1BE6, 0x1BF3}, {0x1C24, 0x1C37}, {0x1CD0, 0x1CD2}, +{0x1CD4, 0x1CE8}, {0x1CED, 0x1CED}, {0x1CF4, 0x1CF4}, {0x1CF7, 0x1CF9}, {0x1DC0, 0x1DF9}, {0x1DFB, 0x1DFF}, {0x20D0, 0x20F0}, {0x2CEF, 0x2CF1}, {0x2D7F, 0x2D7F}, {0x2DE0, 0x2DFF}, {0x302A, 0x302F}, +{0x3099, 0x309A}, {0xA66F, 0xA672}, {0xA674, 0xA67D}, {0xA69E, 0xA69F}, {0xA6F0, 0xA6F1}, {0xA802, 0xA802}, {0xA806, 0xA806}, {0xA80B, 0xA80B}, {0xA823, 0xA827}, {0xA82C, 0xA82C}, {0xA880, 0xA881}, +{0xA8B4, 0xA8C5}, {0xA8E0, 0xA8F1}, {0xA8FF, 0xA8FF}, {0xA926, 0xA92D}, {0xA947, 0xA953}, {0xA980, 0xA983}, 
{0xA9B3, 0xA9C0}, {0xA9E5, 0xA9E5}, {0xAA29, 0xAA36}, {0xAA43, 0xAA43}, {0xAA4C, 0xAA4D}, +{0xAA7B, 0xAA7D}, {0xAAB0, 0xAAB0}, {0xAAB2, 0xAAB4}, {0xAAB7, 0xAAB8}, {0xAABE, 0xAABF}, {0xAAC1, 0xAAC1}, {0xAAEB, 0xAAEF}, {0xAAF5, 0xAAF6}, {0xABE3, 0xABEA}, {0xABEC, 0xABED}, {0xFB1E, 0xFB1E}, +{0xFE00, 0xFE0F}, {0xFE20, 0xFE2F}, {0x101FD, 0x101FD}, {0x102E0, 0x102E0}, {0x10376, 0x1037A}, {0x10A01, 0x10A03}, {0x10A05, 0x10A06}, {0x10A0C, 0x10A0F}, {0x10A38, 0x10A3A}, {0x10A3F, 0x10A3F}, +{0x10AE5, 0x10AE6}, {0x10D24, 0x10D27}, {0x10EAB, 0x10EAC}, {0x10F46, 0x10F50}, {0x11000, 0x11002}, {0x11038, 0x11046}, {0x1107F, 0x11082}, {0x110B0, 0x110BA}, {0x11100, 0x11102}, {0x11127, 0x11134}, +{0x11145, 0x11146}, {0x11173, 0x11173}, {0x11180, 0x11182}, {0x111B3, 0x111C0}, {0x111C9, 0x111CC}, {0x111CE, 0x111CF}, {0x1122C, 0x11237}, {0x1123E, 0x1123E}, {0x112DF, 0x112EA}, {0x11300, 0x11303}, +{0x1133B, 0x1133C}, {0x1133E, 0x11344}, {0x11347, 0x11348}, {0x1134B, 0x1134D}, {0x11357, 0x11357}, {0x11362, 0x11363}, {0x11366, 0x1136C}, {0x11370, 0x11374}, {0x11435, 0x11446}, {0x1145E, 0x1145E}, +{0x114B0, 0x114C3}, {0x115AF, 0x115B5}, {0x115B8, 0x115C0}, {0x115DC, 0x115DD}, {0x11630, 0x11640}, {0x116AB, 0x116B7}, {0x1171D, 0x1172B}, {0x1182C, 0x1183A}, {0x11930, 0x11935}, {0x11937, 0x11938}, +{0x1193B, 0x1193E}, {0x11940, 0x11940}, {0x11942, 0x11943}, {0x119D1, 0x119D7}, {0x119DA, 0x119E0}, {0x119E4, 0x119E4}, {0x11A01, 0x11A0A}, {0x11A33, 0x11A39}, {0x11A3B, 0x11A3E}, {0x11A47, 0x11A47}, +{0x11A51, 0x11A5B}, {0x11A8A, 0x11A99}, {0x11C2F, 0x11C36}, {0x11C38, 0x11C3F}, {0x11C92, 0x11CA7}, {0x11CA9, 0x11CB6}, {0x11D31, 0x11D36}, {0x11D3A, 0x11D3A}, {0x11D3C, 0x11D3D}, {0x11D3F, 0x11D45}, +{0x11D47, 0x11D47}, {0x11D8A, 0x11D8E}, {0x11D90, 0x11D91}, {0x11D93, 0x11D97}, {0x11EF3, 0x11EF6}, {0x16AF0, 0x16AF4}, {0x16B30, 0x16B36}, {0x16F4F, 0x16F4F}, {0x16F51, 0x16F87}, {0x16F8F, 0x16F92}, +{0x16FE4, 0x16FE4}, {0x16FF0, 0x16FF1}, {0x1BC9D, 0x1BC9E}, {0x1D165, 0x1D169}, {0x1D16D, 0x1D172}, {0x1D17B, 0x1D182}, {0x1D185, 0x1D18B}, {0x1D1AA, 0x1D1AD}, {0x1D242, 0x1D244}, {0x1DA00, 0x1DA36}, +{0x1DA3B, 0x1DA6C}, {0x1DA75, 0x1DA75}, {0x1DA84, 0x1DA84}, {0x1DA9B, 0x1DA9F}, {0x1DAA1, 0x1DAAF}, {0x1E000, 0x1E006}, {0x1E008, 0x1E018}, {0x1E01B, 0x1E021}, {0x1E023, 0x1E024}, {0x1E026, 0x1E02A}, +{0x1E130, 0x1E136}, {0x1E2EC, 0x1E2EF}, {0x1E8D0, 0x1E8D6}, {0x1E944, 0x1E94A}, {0xE0100, 0xE01EF}, +}; + +static const std::vector> punctuation_ranges = { +{0x21, 0x23}, {0x25, 0x2A}, {0x2C, 0x2F}, {0x3A, 0x3B}, {0x3F, 0x40}, {0x5B, 0x5D}, {0x5F, 0x5F}, {0x7B, 0x7B}, {0x7D, 0x7D}, {0xA1, 0xA1}, {0xA7, 0xA7}, {0xAB, 0xAB}, {0xB6, 0xB7}, {0xBB, 0xBB}, +{0xBF, 0xBF}, {0x37E, 0x37E}, {0x387, 0x387}, {0x55A, 0x55F}, {0x589, 0x58A}, {0x5BE, 0x5BE}, {0x5C0, 0x5C0}, {0x5C3, 0x5C3}, {0x5C6, 0x5C6}, {0x5F3, 0x5F4}, {0x609, 0x60A}, {0x60C, 0x60D}, +{0x61B, 0x61B}, {0x61E, 0x61F}, {0x66A, 0x66D}, {0x6D4, 0x6D4}, {0x700, 0x70D}, {0x7F7, 0x7F9}, {0x830, 0x83E}, {0x85E, 0x85E}, {0x964, 0x965}, {0x970, 0x970}, {0x9FD, 0x9FD}, {0xA76, 0xA76}, +{0xAF0, 0xAF0}, {0xC77, 0xC77}, {0xC84, 0xC84}, {0xDF4, 0xDF4}, {0xE4F, 0xE4F}, {0xE5A, 0xE5B}, {0xF04, 0xF12}, {0xF14, 0xF14}, {0xF3A, 0xF3D}, {0xF85, 0xF85}, {0xFD0, 0xFD4}, {0xFD9, 0xFDA}, +{0x104A, 0x104F}, {0x10FB, 0x10FB}, {0x1360, 0x1368}, {0x1400, 0x1400}, {0x166E, 0x166E}, {0x169B, 0x169C}, {0x16EB, 0x16ED}, {0x1735, 0x1736}, {0x17D4, 0x17D6}, {0x17D8, 0x17DA}, {0x1800, 0x180A}, +{0x1944, 0x1945}, {0x1A1E, 0x1A1F}, {0x1AA0, 0x1AA6}, {0x1AA8, 0x1AAD}, {0x1B5A, 0x1B60}, {0x1BFC, 0x1BFF}, {0x1C3B, 0x1C3F}, 
{0x1C7E, 0x1C7F}, {0x1CC0, 0x1CC7}, {0x1CD3, 0x1CD3}, {0x2010, 0x2027}, +{0x2030, 0x2043}, {0x2045, 0x2051}, {0x2053, 0x205E}, {0x207D, 0x207E}, {0x208D, 0x208E}, {0x2308, 0x230B}, {0x2329, 0x232A}, {0x2768, 0x2775}, {0x27C5, 0x27C6}, {0x27E6, 0x27EF}, {0x2983, 0x2998}, +{0x29D8, 0x29DB}, {0x29FC, 0x29FD}, {0x2CF9, 0x2CFC}, {0x2CFE, 0x2CFF}, {0x2D70, 0x2D70}, {0x2E00, 0x2E2E}, {0x2E30, 0x2E4F}, {0x2E52, 0x2E52}, {0x3001, 0x3003}, {0x3008, 0x3011}, {0x3014, 0x301F}, +{0x3030, 0x3030}, {0x303D, 0x303D}, {0x30A0, 0x30A0}, {0x30FB, 0x30FB}, {0xA4FE, 0xA4FF}, {0xA60D, 0xA60F}, {0xA673, 0xA673}, {0xA67E, 0xA67E}, {0xA6F2, 0xA6F7}, {0xA874, 0xA877}, {0xA8CE, 0xA8CF}, +{0xA8F8, 0xA8FA}, {0xA8FC, 0xA8FC}, {0xA92E, 0xA92F}, {0xA95F, 0xA95F}, {0xA9C1, 0xA9CD}, {0xA9DE, 0xA9DF}, {0xAA5C, 0xAA5F}, {0xAADE, 0xAADF}, {0xAAF0, 0xAAF1}, {0xABEB, 0xABEB}, {0xFD3E, 0xFD3F}, +{0xFE10, 0xFE19}, {0xFE30, 0xFE52}, {0xFE54, 0xFE61}, {0xFE63, 0xFE63}, {0xFE68, 0xFE68}, {0xFE6A, 0xFE6B}, {0xFF01, 0xFF03}, {0xFF05, 0xFF0A}, {0xFF0C, 0xFF0F}, {0xFF1A, 0xFF1B}, {0xFF1F, 0xFF20}, +{0xFF3B, 0xFF3D}, {0xFF3F, 0xFF3F}, {0xFF5B, 0xFF5B}, {0xFF5D, 0xFF5D}, {0xFF5F, 0xFF65}, {0x10100, 0x10102}, {0x1039F, 0x1039F}, {0x103D0, 0x103D0}, {0x1056F, 0x1056F}, {0x10857, 0x10857}, +{0x1091F, 0x1091F}, {0x1093F, 0x1093F}, {0x10A50, 0x10A58}, {0x10A7F, 0x10A7F}, {0x10AF0, 0x10AF6}, {0x10B39, 0x10B3F}, {0x10B99, 0x10B9C}, {0x10EAD, 0x10EAD}, {0x10F55, 0x10F59}, {0x11047, 0x1104D}, +{0x110BB, 0x110BC}, {0x110BE, 0x110C1}, {0x11140, 0x11143}, {0x11174, 0x11175}, {0x111C5, 0x111C8}, {0x111CD, 0x111CD}, {0x111DB, 0x111DB}, {0x111DD, 0x111DF}, {0x11238, 0x1123D}, {0x112A9, 0x112A9}, +{0x1144B, 0x1144F}, {0x1145A, 0x1145B}, {0x1145D, 0x1145D}, {0x114C6, 0x114C6}, {0x115C1, 0x115D7}, {0x11641, 0x11643}, {0x11660, 0x1166C}, {0x1173C, 0x1173E}, {0x1183B, 0x1183B}, {0x11944, 0x11946}, +{0x119E2, 0x119E2}, {0x11A3F, 0x11A46}, {0x11A9A, 0x11A9C}, {0x11A9E, 0x11AA2}, {0x11C41, 0x11C45}, {0x11C70, 0x11C71}, {0x11EF7, 0x11EF8}, {0x11FFF, 0x11FFF}, {0x12470, 0x12474}, {0x16A6E, 0x16A6F}, +{0x16AF5, 0x16AF5}, {0x16B37, 0x16B3B}, {0x16B44, 0x16B44}, {0x16E97, 0x16E9A}, {0x16FE2, 0x16FE2}, {0x1BC9F, 0x1BC9F}, {0x1DA87, 0x1DA8B}, {0x1E95E, 0x1E95F}, +}; + +static const std::vector> symbol_ranges = { +{0x24, 0x24}, {0x2B, 0x2B}, {0x3C, 0x3E}, {0x5E, 0x5E}, {0x60, 0x60}, {0x7C, 0x7C}, {0x7E, 0x7E}, {0xA2, 0xA6}, {0xA8, 0xA9}, {0xAC, 0xAC}, {0xAE, 0xB1}, {0xB4, 0xB4}, {0xB8, 0xB8}, {0xD7, 0xD7}, +{0xF7, 0xF7}, {0x2C2, 0x2C5}, {0x2D2, 0x2DF}, {0x2E5, 0x2EB}, {0x2ED, 0x2ED}, {0x2EF, 0x2FF}, {0x375, 0x375}, {0x384, 0x385}, {0x3F6, 0x3F6}, {0x482, 0x482}, {0x58D, 0x58F}, {0x606, 0x608}, +{0x60B, 0x60B}, {0x60E, 0x60F}, {0x6DE, 0x6DE}, {0x6E9, 0x6E9}, {0x6FD, 0x6FE}, {0x7F6, 0x7F6}, {0x7FE, 0x7FF}, {0x9F2, 0x9F3}, {0x9FA, 0x9FB}, {0xAF1, 0xAF1}, {0xB70, 0xB70}, {0xBF3, 0xBFA}, +{0xC7F, 0xC7F}, {0xD4F, 0xD4F}, {0xD79, 0xD79}, {0xE3F, 0xE3F}, {0xF01, 0xF03}, {0xF13, 0xF13}, {0xF15, 0xF17}, {0xF1A, 0xF1F}, {0xF34, 0xF34}, {0xF36, 0xF36}, {0xF38, 0xF38}, {0xFBE, 0xFC5}, +{0xFC7, 0xFCC}, {0xFCE, 0xFCF}, {0xFD5, 0xFD8}, {0x109E, 0x109F}, {0x1390, 0x1399}, {0x166D, 0x166D}, {0x17DB, 0x17DB}, {0x1940, 0x1940}, {0x19DE, 0x19FF}, {0x1B61, 0x1B6A}, {0x1B74, 0x1B7C}, +{0x1FBD, 0x1FBD}, {0x1FBF, 0x1FC1}, {0x1FCD, 0x1FCF}, {0x1FDD, 0x1FDF}, {0x1FED, 0x1FEF}, {0x1FFD, 0x1FFE}, {0x2044, 0x2044}, {0x2052, 0x2052}, {0x207A, 0x207C}, {0x208A, 0x208C}, {0x20A0, 0x20BF}, +{0x2100, 0x2101}, {0x2103, 0x2106}, {0x2108, 0x2109}, {0x2114, 0x2114}, {0x2116, 0x2118}, {0x211E, 0x2123}, 
{0x2125, 0x2125}, {0x2127, 0x2127}, {0x2129, 0x2129}, {0x212E, 0x212E}, {0x213A, 0x213B}, +{0x2140, 0x2144}, {0x214A, 0x214D}, {0x214F, 0x214F}, {0x218A, 0x218B}, {0x2190, 0x2307}, {0x230C, 0x2328}, {0x232B, 0x2426}, {0x2440, 0x244A}, {0x249C, 0x24E9}, {0x2500, 0x2767}, {0x2794, 0x27C4}, +{0x27C7, 0x27E5}, {0x27F0, 0x2982}, {0x2999, 0x29D7}, {0x29DC, 0x29FB}, {0x29FE, 0x2B73}, {0x2B76, 0x2B95}, {0x2B97, 0x2BFF}, {0x2CE5, 0x2CEA}, {0x2E50, 0x2E51}, {0x2E80, 0x2E99}, {0x2E9B, 0x2EF3}, +{0x2F00, 0x2FD5}, {0x2FF0, 0x2FFB}, {0x3004, 0x3004}, {0x3012, 0x3013}, {0x3020, 0x3020}, {0x3036, 0x3037}, {0x303E, 0x303F}, {0x309B, 0x309C}, {0x3190, 0x3191}, {0x3196, 0x319F}, {0x31C0, 0x31E3}, +{0x3200, 0x321E}, {0x322A, 0x3247}, {0x3250, 0x3250}, {0x3260, 0x327F}, {0x328A, 0x32B0}, {0x32C0, 0x33FF}, {0x4DC0, 0x4DFF}, {0xA490, 0xA4C6}, {0xA700, 0xA716}, {0xA720, 0xA721}, {0xA789, 0xA78A}, +{0xA828, 0xA82B}, {0xA836, 0xA839}, {0xAA77, 0xAA79}, {0xAB5B, 0xAB5B}, {0xAB6A, 0xAB6B}, {0xFB29, 0xFB29}, {0xFBB2, 0xFBC1}, {0xFDFC, 0xFDFD}, {0xFE62, 0xFE62}, {0xFE64, 0xFE66}, {0xFE69, 0xFE69}, +{0xFF04, 0xFF04}, {0xFF0B, 0xFF0B}, {0xFF1C, 0xFF1E}, {0xFF3E, 0xFF3E}, {0xFF40, 0xFF40}, {0xFF5C, 0xFF5C}, {0xFF5E, 0xFF5E}, {0xFFE0, 0xFFE6}, {0xFFE8, 0xFFEE}, {0xFFFC, 0xFFFD}, {0x10137, 0x1013F}, +{0x10179, 0x10189}, {0x1018C, 0x1018E}, {0x10190, 0x1019C}, {0x101A0, 0x101A0}, {0x101D0, 0x101FC}, {0x10877, 0x10878}, {0x10AC8, 0x10AC8}, {0x1173F, 0x1173F}, {0x11FD5, 0x11FF1}, {0x16B3C, 0x16B3F}, +{0x16B45, 0x16B45}, {0x1BC9C, 0x1BC9C}, {0x1D000, 0x1D0F5}, {0x1D100, 0x1D126}, {0x1D129, 0x1D164}, {0x1D16A, 0x1D16C}, {0x1D183, 0x1D184}, {0x1D18C, 0x1D1A9}, {0x1D1AE, 0x1D1E8}, {0x1D200, 0x1D241}, +{0x1D245, 0x1D245}, {0x1D300, 0x1D356}, {0x1D6C1, 0x1D6C1}, {0x1D6DB, 0x1D6DB}, {0x1D6FB, 0x1D6FB}, {0x1D715, 0x1D715}, {0x1D735, 0x1D735}, {0x1D74F, 0x1D74F}, {0x1D76F, 0x1D76F}, {0x1D789, 0x1D789}, +{0x1D7A9, 0x1D7A9}, {0x1D7C3, 0x1D7C3}, {0x1D800, 0x1D9FF}, {0x1DA37, 0x1DA3A}, {0x1DA6D, 0x1DA74}, {0x1DA76, 0x1DA83}, {0x1DA85, 0x1DA86}, {0x1E14F, 0x1E14F}, {0x1E2FF, 0x1E2FF}, {0x1ECAC, 0x1ECAC}, +{0x1ECB0, 0x1ECB0}, {0x1ED2E, 0x1ED2E}, {0x1EEF0, 0x1EEF1}, {0x1F000, 0x1F02B}, {0x1F030, 0x1F093}, {0x1F0A0, 0x1F0AE}, {0x1F0B1, 0x1F0BF}, {0x1F0C1, 0x1F0CF}, {0x1F0D1, 0x1F0F5}, {0x1F10D, 0x1F1AD}, +{0x1F1E6, 0x1F202}, {0x1F210, 0x1F23B}, {0x1F240, 0x1F248}, {0x1F250, 0x1F251}, {0x1F260, 0x1F265}, {0x1F300, 0x1F6D7}, {0x1F6E0, 0x1F6EC}, {0x1F6F0, 0x1F6FC}, {0x1F700, 0x1F773}, {0x1F780, 0x1F7D8}, +{0x1F7E0, 0x1F7EB}, {0x1F800, 0x1F80B}, {0x1F810, 0x1F847}, {0x1F850, 0x1F859}, {0x1F860, 0x1F887}, {0x1F890, 0x1F8AD}, {0x1F8B0, 0x1F8B1}, {0x1F900, 0x1F978}, {0x1F97A, 0x1F9CB}, {0x1F9CD, 0x1FA53}, +{0x1FA60, 0x1FA6D}, {0x1FA70, 0x1FA74}, {0x1FA78, 0x1FA7A}, {0x1FA80, 0x1FA86}, {0x1FA90, 0x1FAA8}, {0x1FAB0, 0x1FAB6}, {0x1FAC0, 0x1FAC2}, {0x1FAD0, 0x1FAD6}, {0x1FB00, 0x1FB92}, {0x1FB94, 0x1FBCA}, +}; + +static const std::vector> control_ranges = { +{0x0, 0x8}, {0xE, 0x1B}, {0x7F, 0x84}, {0x86, 0x9F}, {0xAD, 0xAD}, {0x378, 0x379}, {0x380, 0x383}, {0x38B, 0x38B}, {0x38D, 0x38D}, {0x3A2, 0x3A2}, {0x530, 0x530}, {0x557, 0x558}, {0x58B, 0x58C}, +{0x590, 0x590}, {0x5C8, 0x5CF}, {0x5EB, 0x5EE}, {0x5F5, 0x605}, {0x61C, 0x61D}, {0x6DD, 0x6DD}, {0x70E, 0x70F}, {0x74B, 0x74C}, {0x7B2, 0x7BF}, {0x7FB, 0x7FC}, {0x82E, 0x82F}, {0x83F, 0x83F}, +{0x85C, 0x85D}, {0x85F, 0x85F}, {0x86B, 0x89F}, {0x8B5, 0x8B5}, {0x8C8, 0x8D2}, {0x8E2, 0x8E2}, {0x984, 0x984}, {0x98D, 0x98E}, {0x991, 0x992}, {0x9A9, 0x9A9}, {0x9B1, 0x9B1}, {0x9B3, 0x9B5}, +{0x9BA, 0x9BB}, {0x9C5, 
0x9C6}, {0x9C9, 0x9CA}, {0x9CF, 0x9D6}, {0x9D8, 0x9DB}, {0x9DE, 0x9DE}, {0x9E4, 0x9E5}, {0x9FF, 0xA00}, {0xA04, 0xA04}, {0xA0B, 0xA0E}, {0xA11, 0xA12}, {0xA29, 0xA29}, +{0xA31, 0xA31}, {0xA34, 0xA34}, {0xA37, 0xA37}, {0xA3A, 0xA3B}, {0xA3D, 0xA3D}, {0xA43, 0xA46}, {0xA49, 0xA4A}, {0xA4E, 0xA50}, {0xA52, 0xA58}, {0xA5D, 0xA5D}, {0xA5F, 0xA65}, {0xA77, 0xA80}, +{0xA84, 0xA84}, {0xA8E, 0xA8E}, {0xA92, 0xA92}, {0xAA9, 0xAA9}, {0xAB1, 0xAB1}, {0xAB4, 0xAB4}, {0xABA, 0xABB}, {0xAC6, 0xAC6}, {0xACA, 0xACA}, {0xACE, 0xACF}, {0xAD1, 0xADF}, {0xAE4, 0xAE5}, +{0xAF2, 0xAF8}, {0xB00, 0xB00}, {0xB04, 0xB04}, {0xB0D, 0xB0E}, {0xB11, 0xB12}, {0xB29, 0xB29}, {0xB31, 0xB31}, {0xB34, 0xB34}, {0xB3A, 0xB3B}, {0xB45, 0xB46}, {0xB49, 0xB4A}, {0xB4E, 0xB54}, +{0xB58, 0xB5B}, {0xB5E, 0xB5E}, {0xB64, 0xB65}, {0xB78, 0xB81}, {0xB84, 0xB84}, {0xB8B, 0xB8D}, {0xB91, 0xB91}, {0xB96, 0xB98}, {0xB9B, 0xB9B}, {0xB9D, 0xB9D}, {0xBA0, 0xBA2}, {0xBA5, 0xBA7}, +{0xBAB, 0xBAD}, {0xBBA, 0xBBD}, {0xBC3, 0xBC5}, {0xBC9, 0xBC9}, {0xBCE, 0xBCF}, {0xBD1, 0xBD6}, {0xBD8, 0xBE5}, {0xBFB, 0xBFF}, {0xC0D, 0xC0D}, {0xC11, 0xC11}, {0xC29, 0xC29}, {0xC3A, 0xC3C}, +{0xC45, 0xC45}, {0xC49, 0xC49}, {0xC4E, 0xC54}, {0xC57, 0xC57}, {0xC5B, 0xC5F}, {0xC64, 0xC65}, {0xC70, 0xC76}, {0xC8D, 0xC8D}, {0xC91, 0xC91}, {0xCA9, 0xCA9}, {0xCB4, 0xCB4}, {0xCBA, 0xCBB}, +{0xCC5, 0xCC5}, {0xCC9, 0xCC9}, {0xCCE, 0xCD4}, {0xCD7, 0xCDD}, {0xCDF, 0xCDF}, {0xCE4, 0xCE5}, {0xCF0, 0xCF0}, {0xCF3, 0xCFF}, {0xD0D, 0xD0D}, {0xD11, 0xD11}, {0xD45, 0xD45}, {0xD49, 0xD49}, +{0xD50, 0xD53}, {0xD64, 0xD65}, {0xD80, 0xD80}, {0xD84, 0xD84}, {0xD97, 0xD99}, {0xDB2, 0xDB2}, {0xDBC, 0xDBC}, {0xDBE, 0xDBF}, {0xDC7, 0xDC9}, {0xDCB, 0xDCE}, {0xDD5, 0xDD5}, {0xDD7, 0xDD7}, +{0xDE0, 0xDE5}, {0xDF0, 0xDF1}, {0xDF5, 0xE00}, {0xE3B, 0xE3E}, {0xE5C, 0xE80}, {0xE83, 0xE83}, {0xE85, 0xE85}, {0xE8B, 0xE8B}, {0xEA4, 0xEA4}, {0xEA6, 0xEA6}, {0xEBE, 0xEBF}, {0xEC5, 0xEC5}, +{0xEC7, 0xEC7}, {0xECE, 0xECF}, {0xEDA, 0xEDB}, {0xEE0, 0xEFF}, {0xF48, 0xF48}, {0xF6D, 0xF70}, {0xF98, 0xF98}, {0xFBD, 0xFBD}, {0xFCD, 0xFCD}, {0xFDB, 0xFFF}, {0x10C6, 0x10C6}, {0x10C8, 0x10CC}, +{0x10CE, 0x10CF}, {0x1249, 0x1249}, {0x124E, 0x124F}, {0x1257, 0x1257}, {0x1259, 0x1259}, {0x125E, 0x125F}, {0x1289, 0x1289}, {0x128E, 0x128F}, {0x12B1, 0x12B1}, {0x12B6, 0x12B7}, {0x12BF, 0x12BF}, +{0x12C1, 0x12C1}, {0x12C6, 0x12C7}, {0x12D7, 0x12D7}, {0x1311, 0x1311}, {0x1316, 0x1317}, {0x135B, 0x135C}, {0x137D, 0x137F}, {0x139A, 0x139F}, {0x13F6, 0x13F7}, {0x13FE, 0x13FF}, {0x169D, 0x169F}, +{0x16F9, 0x16FF}, {0x170D, 0x170D}, {0x1715, 0x171F}, {0x1737, 0x173F}, {0x1754, 0x175F}, {0x176D, 0x176D}, {0x1771, 0x1771}, {0x1774, 0x177F}, {0x17DE, 0x17DF}, {0x17EA, 0x17EF}, {0x17FA, 0x17FF}, +{0x180E, 0x180F}, {0x181A, 0x181F}, {0x1879, 0x187F}, {0x18AB, 0x18AF}, {0x18F6, 0x18FF}, {0x191F, 0x191F}, {0x192C, 0x192F}, {0x193C, 0x193F}, {0x1941, 0x1943}, {0x196E, 0x196F}, {0x1975, 0x197F}, +{0x19AC, 0x19AF}, {0x19CA, 0x19CF}, {0x19DB, 0x19DD}, {0x1A1C, 0x1A1D}, {0x1A5F, 0x1A5F}, {0x1A7D, 0x1A7E}, {0x1A8A, 0x1A8F}, {0x1A9A, 0x1A9F}, {0x1AAE, 0x1AAF}, {0x1AC1, 0x1AFF}, {0x1B4C, 0x1B4F}, +{0x1B7D, 0x1B7F}, {0x1BF4, 0x1BFB}, {0x1C38, 0x1C3A}, {0x1C4A, 0x1C4C}, {0x1C89, 0x1C8F}, {0x1CBB, 0x1CBC}, {0x1CC8, 0x1CCF}, {0x1CFB, 0x1CFF}, {0x1DFA, 0x1DFA}, {0x1F16, 0x1F17}, {0x1F1E, 0x1F1F}, +{0x1F46, 0x1F47}, {0x1F4E, 0x1F4F}, {0x1F58, 0x1F58}, {0x1F5A, 0x1F5A}, {0x1F5C, 0x1F5C}, {0x1F5E, 0x1F5E}, {0x1F7E, 0x1F7F}, {0x1FB5, 0x1FB5}, {0x1FC5, 0x1FC5}, {0x1FD4, 0x1FD5}, {0x1FDC, 0x1FDC}, +{0x1FF0, 0x1FF1}, {0x1FF5, 0x1FF5}, {0x1FFF, 0x1FFF}, 
{0x200B, 0x200F}, {0x202A, 0x202E}, {0x2060, 0x206F}, {0x2072, 0x2073}, {0x208F, 0x208F}, {0x209D, 0x209F}, {0x20C0, 0x20CF}, {0x20F1, 0x20FF}, +{0x218C, 0x218F}, {0x2427, 0x243F}, {0x244B, 0x245F}, {0x2B74, 0x2B75}, {0x2B96, 0x2B96}, {0x2C2F, 0x2C2F}, {0x2C5F, 0x2C5F}, {0x2CF4, 0x2CF8}, {0x2D26, 0x2D26}, {0x2D28, 0x2D2C}, {0x2D2E, 0x2D2F}, +{0x2D68, 0x2D6E}, {0x2D71, 0x2D7E}, {0x2D97, 0x2D9F}, {0x2DA7, 0x2DA7}, {0x2DAF, 0x2DAF}, {0x2DB7, 0x2DB7}, {0x2DBF, 0x2DBF}, {0x2DC7, 0x2DC7}, {0x2DCF, 0x2DCF}, {0x2DD7, 0x2DD7}, {0x2DDF, 0x2DDF}, +{0x2E53, 0x2E7F}, {0x2E9A, 0x2E9A}, {0x2EF4, 0x2EFF}, {0x2FD6, 0x2FEF}, {0x2FFC, 0x2FFF}, {0x3040, 0x3040}, {0x3097, 0x3098}, {0x3100, 0x3104}, {0x3130, 0x3130}, {0x318F, 0x318F}, {0x31E4, 0x31EF}, +{0x321F, 0x321F}, {0x9FFD, 0x9FFF}, {0xA48D, 0xA48F}, {0xA4C7, 0xA4CF}, {0xA62C, 0xA63F}, {0xA6F8, 0xA6FF}, {0xA7C0, 0xA7C1}, {0xA7CB, 0xA7F4}, {0xA82D, 0xA82F}, {0xA83A, 0xA83F}, {0xA878, 0xA87F}, +{0xA8C6, 0xA8CD}, {0xA8DA, 0xA8DF}, {0xA954, 0xA95E}, {0xA97D, 0xA97F}, {0xA9CE, 0xA9CE}, {0xA9DA, 0xA9DD}, {0xA9FF, 0xA9FF}, {0xAA37, 0xAA3F}, {0xAA4E, 0xAA4F}, {0xAA5A, 0xAA5B}, {0xAAC3, 0xAADA}, +{0xAAF7, 0xAB00}, {0xAB07, 0xAB08}, {0xAB0F, 0xAB10}, {0xAB17, 0xAB1F}, {0xAB27, 0xAB27}, {0xAB2F, 0xAB2F}, {0xAB6C, 0xAB6F}, {0xABEE, 0xABEF}, {0xABFA, 0xABFF}, {0xD7A4, 0xD7AF}, {0xD7C7, 0xD7CA}, +{0xD7FC, 0xF8FF}, {0xFA6E, 0xFA6F}, {0xFADA, 0xFAFF}, {0xFB07, 0xFB12}, {0xFB18, 0xFB1C}, {0xFB37, 0xFB37}, {0xFB3D, 0xFB3D}, {0xFB3F, 0xFB3F}, {0xFB42, 0xFB42}, {0xFB45, 0xFB45}, {0xFBC2, 0xFBD2}, +{0xFD40, 0xFD4F}, {0xFD90, 0xFD91}, {0xFDC8, 0xFDEF}, {0xFDFE, 0xFDFF}, {0xFE1A, 0xFE1F}, {0xFE53, 0xFE53}, {0xFE67, 0xFE67}, {0xFE6C, 0xFE6F}, {0xFE75, 0xFE75}, {0xFEFD, 0xFF00}, {0xFFBF, 0xFFC1}, +{0xFFC8, 0xFFC9}, {0xFFD0, 0xFFD1}, {0xFFD8, 0xFFD9}, {0xFFDD, 0xFFDF}, {0xFFE7, 0xFFE7}, {0xFFEF, 0xFFFB}, {0xFFFE, 0xFFFF}, {0x1000C, 0x1000C}, {0x10027, 0x10027}, {0x1003B, 0x1003B}, +{0x1003E, 0x1003E}, {0x1004E, 0x1004F}, {0x1005E, 0x1007F}, {0x100FB, 0x100FF}, {0x10103, 0x10106}, {0x10134, 0x10136}, {0x1018F, 0x1018F}, {0x1019D, 0x1019F}, {0x101A1, 0x101CF}, {0x101FE, 0x1027F}, +{0x1029D, 0x1029F}, {0x102D1, 0x102DF}, {0x102FC, 0x102FF}, {0x10324, 0x1032C}, {0x1034B, 0x1034F}, {0x1037B, 0x1037F}, {0x1039E, 0x1039E}, {0x103C4, 0x103C7}, {0x103D6, 0x103FF}, {0x1049E, 0x1049F}, +{0x104AA, 0x104AF}, {0x104D4, 0x104D7}, {0x104FC, 0x104FF}, {0x10528, 0x1052F}, {0x10564, 0x1056E}, {0x10570, 0x105FF}, {0x10737, 0x1073F}, {0x10756, 0x1075F}, {0x10768, 0x107FF}, {0x10806, 0x10807}, +{0x10809, 0x10809}, {0x10836, 0x10836}, {0x10839, 0x1083B}, {0x1083D, 0x1083E}, {0x10856, 0x10856}, {0x1089F, 0x108A6}, {0x108B0, 0x108DF}, {0x108F3, 0x108F3}, {0x108F6, 0x108FA}, {0x1091C, 0x1091E}, +{0x1093A, 0x1093E}, {0x10940, 0x1097F}, {0x109B8, 0x109BB}, {0x109D0, 0x109D1}, {0x10A04, 0x10A04}, {0x10A07, 0x10A0B}, {0x10A14, 0x10A14}, {0x10A18, 0x10A18}, {0x10A36, 0x10A37}, {0x10A3B, 0x10A3E}, +{0x10A49, 0x10A4F}, {0x10A59, 0x10A5F}, {0x10AA0, 0x10ABF}, {0x10AE7, 0x10AEA}, {0x10AF7, 0x10AFF}, {0x10B36, 0x10B38}, {0x10B56, 0x10B57}, {0x10B73, 0x10B77}, {0x10B92, 0x10B98}, {0x10B9D, 0x10BA8}, +{0x10BB0, 0x10BFF}, {0x10C49, 0x10C7F}, {0x10CB3, 0x10CBF}, {0x10CF3, 0x10CF9}, {0x10D28, 0x10D2F}, {0x10D3A, 0x10E5F}, {0x10E7F, 0x10E7F}, {0x10EAA, 0x10EAA}, {0x10EAE, 0x10EAF}, {0x10EB2, 0x10EFF}, +{0x10F28, 0x10F2F}, {0x10F5A, 0x10FAF}, {0x10FCC, 0x10FDF}, {0x10FF7, 0x10FFF}, {0x1104E, 0x11051}, {0x11070, 0x1107E}, {0x110BD, 0x110BD}, {0x110C2, 0x110CF}, {0x110E9, 0x110EF}, {0x110FA, 0x110FF}, +{0x11135, 0x11135}, 
{0x11148, 0x1114F}, {0x11177, 0x1117F}, {0x111E0, 0x111E0}, {0x111F5, 0x111FF}, {0x11212, 0x11212}, {0x1123F, 0x1127F}, {0x11287, 0x11287}, {0x11289, 0x11289}, {0x1128E, 0x1128E}, +{0x1129E, 0x1129E}, {0x112AA, 0x112AF}, {0x112EB, 0x112EF}, {0x112FA, 0x112FF}, {0x11304, 0x11304}, {0x1130D, 0x1130E}, {0x11311, 0x11312}, {0x11329, 0x11329}, {0x11331, 0x11331}, {0x11334, 0x11334}, +{0x1133A, 0x1133A}, {0x11345, 0x11346}, {0x11349, 0x1134A}, {0x1134E, 0x1134F}, {0x11351, 0x11356}, {0x11358, 0x1135C}, {0x11364, 0x11365}, {0x1136D, 0x1136F}, {0x11375, 0x113FF}, {0x1145C, 0x1145C}, +{0x11462, 0x1147F}, {0x114C8, 0x114CF}, {0x114DA, 0x1157F}, {0x115B6, 0x115B7}, {0x115DE, 0x115FF}, {0x11645, 0x1164F}, {0x1165A, 0x1165F}, {0x1166D, 0x1167F}, {0x116B9, 0x116BF}, {0x116CA, 0x116FF}, +{0x1171B, 0x1171C}, {0x1172C, 0x1172F}, {0x11740, 0x117FF}, {0x1183C, 0x1189F}, {0x118F3, 0x118FE}, {0x11907, 0x11908}, {0x1190A, 0x1190B}, {0x11914, 0x11914}, {0x11917, 0x11917}, {0x11936, 0x11936}, +{0x11939, 0x1193A}, {0x11947, 0x1194F}, {0x1195A, 0x1199F}, {0x119A8, 0x119A9}, {0x119D8, 0x119D9}, {0x119E5, 0x119FF}, {0x11A48, 0x11A4F}, {0x11AA3, 0x11ABF}, {0x11AF9, 0x11BFF}, {0x11C09, 0x11C09}, +{0x11C37, 0x11C37}, {0x11C46, 0x11C4F}, {0x11C6D, 0x11C6F}, {0x11C90, 0x11C91}, {0x11CA8, 0x11CA8}, {0x11CB7, 0x11CFF}, {0x11D07, 0x11D07}, {0x11D0A, 0x11D0A}, {0x11D37, 0x11D39}, {0x11D3B, 0x11D3B}, +{0x11D3E, 0x11D3E}, {0x11D48, 0x11D4F}, {0x11D5A, 0x11D5F}, {0x11D66, 0x11D66}, {0x11D69, 0x11D69}, {0x11D8F, 0x11D8F}, {0x11D92, 0x11D92}, {0x11D99, 0x11D9F}, {0x11DAA, 0x11EDF}, {0x11EF9, 0x11FAF}, +{0x11FB1, 0x11FBF}, {0x11FF2, 0x11FFE}, {0x1239A, 0x123FF}, {0x1246F, 0x1246F}, {0x12475, 0x1247F}, {0x12544, 0x12FFF}, {0x1342F, 0x143FF}, {0x14647, 0x167FF}, {0x16A39, 0x16A3F}, {0x16A5F, 0x16A5F}, +{0x16A6A, 0x16A6D}, {0x16A70, 0x16ACF}, {0x16AEE, 0x16AEF}, {0x16AF6, 0x16AFF}, {0x16B46, 0x16B4F}, {0x16B5A, 0x16B5A}, {0x16B62, 0x16B62}, {0x16B78, 0x16B7C}, {0x16B90, 0x16E3F}, {0x16E9B, 0x16EFF}, +{0x16F4B, 0x16F4E}, {0x16F88, 0x16F8E}, {0x16FA0, 0x16FDF}, {0x16FE5, 0x16FEF}, {0x16FF2, 0x16FFF}, {0x187F8, 0x187FF}, {0x18CD6, 0x18CFF}, {0x18D09, 0x1AFFF}, {0x1B11F, 0x1B14F}, {0x1B153, 0x1B163}, +{0x1B168, 0x1B16F}, {0x1B2FC, 0x1BBFF}, {0x1BC6B, 0x1BC6F}, {0x1BC7D, 0x1BC7F}, {0x1BC89, 0x1BC8F}, {0x1BC9A, 0x1BC9B}, {0x1BCA0, 0x1CFFF}, {0x1D0F6, 0x1D0FF}, {0x1D127, 0x1D128}, {0x1D173, 0x1D17A}, +{0x1D1E9, 0x1D1FF}, {0x1D246, 0x1D2DF}, {0x1D2F4, 0x1D2FF}, {0x1D357, 0x1D35F}, {0x1D379, 0x1D3FF}, {0x1D455, 0x1D455}, {0x1D49D, 0x1D49D}, {0x1D4A0, 0x1D4A1}, {0x1D4A3, 0x1D4A4}, {0x1D4A7, 0x1D4A8}, +{0x1D4AD, 0x1D4AD}, {0x1D4BA, 0x1D4BA}, {0x1D4BC, 0x1D4BC}, {0x1D4C4, 0x1D4C4}, {0x1D506, 0x1D506}, {0x1D50B, 0x1D50C}, {0x1D515, 0x1D515}, {0x1D51D, 0x1D51D}, {0x1D53A, 0x1D53A}, {0x1D53F, 0x1D53F}, +{0x1D545, 0x1D545}, {0x1D547, 0x1D549}, {0x1D551, 0x1D551}, {0x1D6A6, 0x1D6A7}, {0x1D7CC, 0x1D7CD}, {0x1DA8C, 0x1DA9A}, {0x1DAA0, 0x1DAA0}, {0x1DAB0, 0x1DFFF}, {0x1E007, 0x1E007}, {0x1E019, 0x1E01A}, +{0x1E022, 0x1E022}, {0x1E025, 0x1E025}, {0x1E02B, 0x1E0FF}, {0x1E12D, 0x1E12F}, {0x1E13E, 0x1E13F}, {0x1E14A, 0x1E14D}, {0x1E150, 0x1E2BF}, {0x1E2FA, 0x1E2FE}, {0x1E300, 0x1E7FF}, {0x1E8C5, 0x1E8C6}, +{0x1E8D7, 0x1E8FF}, {0x1E94C, 0x1E94F}, {0x1E95A, 0x1E95D}, {0x1E960, 0x1EC70}, {0x1ECB5, 0x1ED00}, {0x1ED3E, 0x1EDFF}, {0x1EE04, 0x1EE04}, {0x1EE20, 0x1EE20}, {0x1EE23, 0x1EE23}, {0x1EE25, 0x1EE26}, +{0x1EE28, 0x1EE28}, {0x1EE33, 0x1EE33}, {0x1EE38, 0x1EE38}, {0x1EE3A, 0x1EE3A}, {0x1EE3C, 0x1EE41}, {0x1EE43, 0x1EE46}, {0x1EE48, 0x1EE48}, {0x1EE4A, 
0x1EE4A}, {0x1EE4C, 0x1EE4C}, {0x1EE50, 0x1EE50}, +{0x1EE53, 0x1EE53}, {0x1EE55, 0x1EE56}, {0x1EE58, 0x1EE58}, {0x1EE5A, 0x1EE5A}, {0x1EE5C, 0x1EE5C}, {0x1EE5E, 0x1EE5E}, {0x1EE60, 0x1EE60}, {0x1EE63, 0x1EE63}, {0x1EE65, 0x1EE66}, {0x1EE6B, 0x1EE6B}, +{0x1EE73, 0x1EE73}, {0x1EE78, 0x1EE78}, {0x1EE7D, 0x1EE7D}, {0x1EE7F, 0x1EE7F}, {0x1EE8A, 0x1EE8A}, {0x1EE9C, 0x1EEA0}, {0x1EEA4, 0x1EEA4}, {0x1EEAA, 0x1EEAA}, {0x1EEBC, 0x1EEEF}, {0x1EEF2, 0x1EFFF}, +{0x1F02C, 0x1F02F}, {0x1F094, 0x1F09F}, {0x1F0AF, 0x1F0B0}, {0x1F0C0, 0x1F0C0}, {0x1F0D0, 0x1F0D0}, {0x1F0F6, 0x1F0FF}, {0x1F1AE, 0x1F1E5}, {0x1F203, 0x1F20F}, {0x1F23C, 0x1F23F}, {0x1F249, 0x1F24F}, +{0x1F252, 0x1F25F}, {0x1F266, 0x1F2FF}, {0x1F6D8, 0x1F6DF}, {0x1F6ED, 0x1F6EF}, {0x1F6FD, 0x1F6FF}, {0x1F774, 0x1F77F}, {0x1F7D9, 0x1F7DF}, {0x1F7EC, 0x1F7FF}, {0x1F80C, 0x1F80F}, {0x1F848, 0x1F84F}, +{0x1F85A, 0x1F85F}, {0x1F888, 0x1F88F}, {0x1F8AE, 0x1F8AF}, {0x1F8B2, 0x1F8FF}, {0x1F979, 0x1F979}, {0x1F9CC, 0x1F9CC}, {0x1FA54, 0x1FA5F}, {0x1FA6E, 0x1FA6F}, {0x1FA75, 0x1FA77}, {0x1FA7B, 0x1FA7F}, +{0x1FA87, 0x1FA8F}, {0x1FAA9, 0x1FAAF}, {0x1FAB7, 0x1FABF}, {0x1FAC3, 0x1FACF}, {0x1FAD7, 0x1FAFF}, {0x1FB93, 0x1FB93}, {0x1FBCB, 0x1FBEF}, {0x1FBFA, 0x1FFFF}, {0x2A6DE, 0x2A6FF}, {0x2B735, 0x2B73F}, +{0x2B81E, 0x2B81F}, {0x2CEA2, 0x2CEAF}, {0x2EBE1, 0x2F7FF}, {0x2FA1E, 0x2FFFF}, {0x3134B, 0xE00FF}, {0xE01F0, 0x10FFFF}, +}; + +static std::string codepoint_to_utf8(uint32_t cp) { + std::string result; + if (/* 0x00 <= cp && */ cp <= 0x7f) { + result.push_back(cp); + } + else if (0x80 <= cp && cp <= 0x7ff) { + result.push_back(0xc0 | ((cp >> 6) & 0x1f)); + result.push_back(0x80 | (cp & 0x3f)); + } + else if (0x800 <= cp && cp <= 0xffff) { + result.push_back(0xe0 | ((cp >> 12) & 0x0f)); + result.push_back(0x80 | ((cp >> 6) & 0x3f)); + result.push_back(0x80 | (cp & 0x3f)); + } + else if (0x10000 <= cp && cp <= 0x10ffff) { + result.push_back(0xf0 | ((cp >> 18) & 0x07)); + result.push_back(0x80 | ((cp >> 12) & 0x3f)); + result.push_back(0x80 | ((cp >> 6) & 0x3f)); + result.push_back(0x80 | (cp & 0x3f)); + } + else { + throw std::invalid_argument("invalid codepoint"); + } + return result; +} + +static std::string codepoints_to_utf8(const std::vector & cps) { + std::string result; + for (size_t i = 0; i < cps.size(); ++i) { + result.append(codepoint_to_utf8(cps[i])); + } + return result; +} + +static uint32_t codepoint_from_utf8(const std::string & utf8, size_t & offset) { + assert(offset < utf8.size()); + if (!(utf8[offset + 0] & 0x80)) { + auto result = utf8[offset + 0]; + offset += 1; + return result; + } + else if (!(utf8[offset + 0] & 0x40)) { + throw std::invalid_argument("invalid character"); + } + else if (!(utf8[offset + 0] & 0x20)) { + if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) + throw std::invalid_argument("invalid character"); + auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f); + offset += 2; + return result; + } + else if (!(utf8[offset + 0] & 0x10)) { + if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) + throw std::invalid_argument("invalid character"); + auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); + offset += 3; + return result; + } + else if (!(utf8[offset + 0] & 0x08)) { + if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! 
((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) + throw std::invalid_argument("invalid character"); + auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f); + offset += 4; + return result; + } + throw std::invalid_argument("invalid string"); +} + +static std::vector<uint32_t> codepoints_from_utf8(const std::string & utf8) { + std::vector<uint32_t> result; + size_t offset = 0; + while (offset < utf8.size()) { + result.push_back(codepoint_from_utf8(utf8, offset)); + } + return result; +} + +static std::vector<uint16_t> codepoint_to_utf16(uint32_t cp) { + std::vector<uint16_t> result; + if (/* 0x0000 <= cp && */ cp <= 0xffff) { + result.emplace_back(cp); + } + else if (0x10000 <= cp && cp <= 0x10ffff) { + result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); + result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); + } + else { + throw std::invalid_argument("invalid codepoint"); + } + return result; +} + +static std::vector<uint16_t> codepoints_to_utf16(const std::vector<uint32_t> & cps) { + std::vector<uint16_t> result; + for (size_t i = 0; i < cps.size(); ++i) { + auto temp = codepoint_to_utf16(cps[i]); + result.insert(result.end(), temp.begin(), temp.end()); + } + return result; +} + +static uint32_t codepoint_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) { + assert(offset < utf16.size()); + if (((utf16[0] >> 10) << 10) != 0xd800) { + auto result = utf16[offset + 0]; + offset += 1; + return result; + } + else { + if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) + throw std::invalid_argument("invalid character"); + auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); + offset += 2; + return result; + } + throw std::invalid_argument("invalid string"); +} + +static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> & utf16) { + std::vector<uint32_t> result; + size_t offset = 0; + while (offset < utf16.size()) + result.push_back(codepoint_from_utf16(utf16, offset)); + return result; +} + +#define CODEPOINT_TYPE_UNIDENTIFIED 0 +#define CODEPOINT_TYPE_DIGIT 1 +#define CODEPOINT_TYPE_LETTER 2 +#define CODEPOINT_TYPE_WHITESPACE 3 +#define CODEPOINT_TYPE_ACCENT_MARK 4 +#define CODEPOINT_TYPE_PUNCTUATION 5 +#define CODEPOINT_TYPE_SYMBOL 6 +#define CODEPOINT_TYPE_CONTROL 7 + +static std::unordered_map<uint32_t, int> codepoint_type_map() { + std::unordered_map<uint32_t, int> codepoint_types; + for (auto p : digit_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_DIGIT; + } + for(auto p : letter_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_LETTER; + } + for(auto p : whitespace_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE; + } + for(auto p : accent_mark_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK; + } + for(auto p : punctuation_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION; + } + for (auto p : symbol_ranges) { + for (auto i = p.first; i <= p.second; ++i) + codepoint_types[i] = CODEPOINT_TYPE_SYMBOL; + } + for(auto p : control_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_CONTROL; + } + return codepoint_types; +} + +static int codepoint_type(uint32_t cp) { + static std::unordered_map<uint32_t, int> codepoint_types = codepoint_type_map(); + return codepoint_types[cp]; +} + +static int codepoint_type(const std::string & utf8) { + if
(utf8.length() == 0) + return CODEPOINT_TYPE_UNIDENTIFIED; + size_t offset = 0; + return codepoint_type(codepoint_from_utf8(utf8, offset)); +} + +static std::unordered_map bytes_to_unicode_map_bpe() { + std::unordered_map map; + for (int ch = u'!'; ch <= u'~'; ++ch) { + assert(0 <= ch && ch < 256); + map[ch] = codepoint_to_utf8(ch); + } + for (int ch = u'¡'; ch <= u'¬'; ++ch) { + assert(0 <= ch && ch < 256); + map[ch] = codepoint_to_utf8(ch); + } + for (int ch = u'®'; ch <= u'ÿ'; ++ch) { + assert(0 <= ch && ch < 256); + map[ch] = codepoint_to_utf8(ch); + } + auto n = 0; + for (int ch = 0; ch < 256; ++ch) { + if (map.find(ch) == map.end()) { + map[ch] = codepoint_to_utf8(256 + n); + ++n; + } + } + return map; +} + +static std::string bytes_to_unicode_bpe(uint8_t byte) { + static std::unordered_map map = bytes_to_unicode_map_bpe(); + return map.at(byte); +} + +static std::unordered_map unicode_to_bytes_map_bpe() { + std::unordered_map map; + for (int ch = u'!'; ch <= u'~'; ++ch) { + assert(0 <= ch && ch < 256); + map[codepoint_to_utf8(ch)] = ch; + } + for (int ch = u'¡'; ch <= u'¬'; ++ch) { + assert(0 <= ch && ch < 256); + map[codepoint_to_utf8(ch)] = ch; + } + for (int ch = u'®'; ch <= u'ÿ'; ++ch) { + assert(0 <= ch && ch < 256); + map[codepoint_to_utf8(ch)] = ch; + } + auto n = 0; + for (int ch = 0; ch < 256; ++ch) { + if (map.find(codepoint_to_utf8(ch)) == map.end()) { + map[codepoint_to_utf8(256 + n)] = ch; + ++n; + } + } + return map; +} + +static uint8_t unicode_to_bytes_bpe(const std::string & utf8) { + static std::unordered_map map = unicode_to_bytes_map_bpe(); + return map.at(utf8); +} + diff --git a/applications/src/llama/version b/applications/src/llama/version new file mode 100644 index 000000000..135f41c07 --- /dev/null +++ b/applications/src/llama/version @@ -0,0 +1 @@ +Commit: 3c04bf6da89eaf4c7d317e0518f0687dfcbf2de7 \ No newline at end of file diff --git a/applications/test_applications.py b/applications/test_applications.py new file mode 100644 index 000000000..38fbb3ef9 --- /dev/null +++ b/applications/test_applications.py @@ -0,0 +1,95 @@ +# ===------ test_applications.py ------------------------ *- Python -* ----=== # +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
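
Not part of the patch: a minimal standalone sketch of the byte-to-codepoint mapping that bytes_to_unicode_map_bpe() / unicode_to_bytes_map_bpe() above implement. Printable byte values keep their own codepoint, and the remaining byte values are assigned codepoints 256, 257, ... so every byte gets a printable, invertible stand-in. The helper below is illustrative only and assumes nothing beyond the standard library.

// Standalone sketch of the BPE byte <-> codepoint round trip.
#include <cassert>
#include <map>

int main() {
    std::map<int, int> byte_to_cp; // byte value -> stand-in codepoint
    int next_free = 256;
    for (int b = 0; b < 256; ++b) {
        const bool printable = (b >= 0x21 && b <= 0x7E) || // '!'..'~'
                               (b >= 0xA1 && b <= 0xAC) || // '¡'..'¬'
                               (b >= 0xAE && b <= 0xFF);   // '®'..'ÿ'
        byte_to_cp[b] = printable ? b : next_free++;
    }
    std::map<int, int> cp_to_byte; // inverse mapping
    for (const auto &kv : byte_to_cp) cp_to_byte[kv.second] = kv.first;
    for (int b = 0; b < 256; ++b) assert(cp_to_byte.at(byte_to_cp.at(b)) == b);
    return 0;
}
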
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# ===----------------------------------------------------------------------=== # + +import os +import re +import sys +from pathlib import Path +parent = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(parent) + +from test_utils import * + +def setup_test(): + + return True + +def migrate_test(): + src = [] + extra_args = [] + in_root = os.path.join(os.getcwd(), test_config.current_test) + test_config.out_root = os.path.join(in_root, 'out_root') + + if test_config.current_test != 'llama': + for dirpath, dirnames, filenames in os.walk(in_root): + for filename in [f for f in filenames if re.match('.*(cu|cpp|c)$', f)]: + src.append(os.path.abspath(os.path.join(dirpath, filename))) + + if test_config.current_test == 'llama': + original_string = '/export/users/placeholder/project/llama.cpp' + new_string = in_root + db_dir = os.path.join(in_root, 'gpubuild') + with open(os.path.join(db_dir, 'compile_commands.json'), "r") as file: + file_contents = file.read() + file_contents = file_contents.replace(original_string, new_string) + db_dir = db_dir.replace('\\', '\\\\') + file_contents = re.sub(r'"directory": ".*?"', '"directory": "' + db_dir + '"', file_contents) + file_contents = file_contents.replace('\\', '\\\\') + with open(os.path.join(db_dir, 'compile_commands.json'), "w") as file: + file.write(file_contents) + src.append(' -p=' + os.path.join(in_root, 'gpubuild')) + src.append(' --enable-profiling ') + src.append(' --use-experimental-features=free-function-queries,local-memory-kernel-scope-allocation ') + + return do_migrate(src, in_root, test_config.out_root, extra_args) + +def build_test(): + if (os.path.exists(test_config.current_test)): + os.chdir(test_config.current_test) + srcs = [] + cmp_opts = [] + link_opts = [] + objects = [] + + llama_files = ['common.cpp', 'sampling.cpp.dp.cpp', 'console.cpp', 'grammar-parser.cpp', 'train.cpp.dp.cpp', + 'build-info.cpp', 'llama.cpp', 'ggml.c', 'ggml-alloc.c', 'ggml-backend.c', 'ggml-quants.c', + 'ggml-cuda.dp.cpp', 'main.cpp'] + + if test_config.current_test == 'llama': + if platform.system() == 'Linux': + link_opts = test_config.mkl_link_opt_lin + else: + link_opts = test_config.mkl_link_opt_win + cmp_opts.append("-DMKL_ILP64") + cmp_opts.append("-DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128") + cmp_opts.append("-DGGML_USE_CUBLAS -DK_QUANTS_PER_ITERATION=2 -D_GNU_SOURCE -D_XOPEN_SOURCE=600 -O3 -DNDEBUG") + cmp_opts.append("-I " + test_config.out_root) + cmp_opts.append("-I " + os.path.join(test_config.out_root, 'common')) + + for dirpath, dirnames, filenames in os.walk(test_config.out_root): + if test_config.current_test == 'llama': + for filename in filenames: + if filename in llama_files: + srcs.append(os.path.abspath(os.path.join(dirpath, filename))) + else: + for filename in [f for f in filenames if re.match('.*(cpp|c)$', f)]: + srcs.append(os.path.abspath(os.path.join(dirpath, filename))) + + if platform.system() == 'Linux': + link_opts.append(' -lpthread ') + + ret = False + + if test_config.current_test == 'llama': + ret = compile_and_link(srcs, cmp_opts, objects, link_opts) + else: + ret = compile_files(srcs, cmp_opts) + return ret + +def run_test(): + return True \ No newline at end of file diff --git a/applications/test_drivers.xml b/applications/test_drivers.xml new file mode 100644 index 000000000..70d43b428 --- /dev/null +++ b/applications/test_drivers.xml @@ -0,0 +1,6 @@ + + + + APP test suite. 
+ + diff --git a/applications/timelimit.xml b/applications/timelimit.xml new file mode 100644 index 000000000..701ca9e76 --- /dev/null +++ b/applications/timelimit.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/behavior_tests/behavior_tests.xml b/behavior_tests/behavior_tests.xml index f95bb9bf8..8c2f0071e 100644 --- a/behavior_tests/behavior_tests.xml +++ b/behavior_tests/behavior_tests.xml @@ -80,8 +80,8 @@ - - + + @@ -116,7 +116,6 @@ - @@ -125,7 +124,6 @@ - @@ -168,6 +166,7 @@ + diff --git a/behavior_tests/src/analysis-mode/do_test.py b/behavior_tests/src/analysis-mode/do_test.py new file mode 100644 index 000000000..517eee254 --- /dev/null +++ b/behavior_tests/src/analysis-mode/do_test.py @@ -0,0 +1,62 @@ +# ====------ do_test.py---------- *- Python -* ----===## +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# +# ===----------------------------------------------------------------------===# +import subprocess +import platform +import os +import sys + +from test_utils import * + + +def setup_test(): + change_dir(test_config.current_test) + return True + + +def migrate_test(): + call_subprocess( + test_config.CT_TOOL + " test.cu -analysis-mode --cuda-include-path=" + test_config.include_path) + + + reference_lines = [ + ["test.cu:"], + ["lines of code", "will be automatically migrated."], + ["APIs/Types - No manual effort."], + ["APIs/Types - Low manual effort for checking and code fixing."], + ["APIs/Types - Medium manual effort for code fixing."], + ["lines of code", "will not be automatically migrated."], + ["APIs/Types - High manual effort for code fixing."], + ["Total Project:"], + ["lines of code", "will be automatically migrated."], + ["APIs/Types - No manual effort."], + ["APIs/Types - Low manual effort for checking and code fixing."], + ["APIs/Types - Medium manual effort for code fixing."], + ["lines of code", "will not be automatically migrated."], + ["APIs/Types - High manual effort for code fixing."]] + + ret_lines = test_config.command_output.splitlines() + res = True + idx = 5 + for refs in reference_lines: + ret = ret_lines[idx] + for ref in refs: + if ref not in ret: + res = False + print("there should be a '" + ref + "' in output.") + idx = idx + 1 + + return res + + +def build_test(): + return True + + +def run_test(): + return True diff --git a/behavior_tests/src/analysis-mode/test.cu b/behavior_tests/src/analysis-mode/test.cu new file mode 100644 index 000000000..9b85d5bba --- /dev/null +++ b/behavior_tests/src/analysis-mode/test.cu @@ -0,0 +1,22 @@ +// ====------ test.cu---------- *- CUDA -* ----===//// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// +// ===----------------------------------------------------------------------===// + +__global__ void kernel(int *a) { + *a = clock64(); + __syncthreads(); + int b = *a; +} + +void foo() { + int *a; + size_t b, c; + cudaMemGetInfo(&b, &c); + cudaMalloc(&a, sizeof(int)); + cudaFree(a); +} \ No newline at end of file diff --git a/behavior_tests/src/bt-autocomplete/do_test.py b/behavior_tests/src/bt-autocomplete/do_test.py index f8ead03fa..532965749 100644 --- a/behavior_tests/src/bt-autocomplete/do_test.py +++ b/behavior_tests/src/bt-autocomplete/do_test.py @@ -55,10 +55,14 @@ def migrate_test(): 'silent\n' res = res and (reference == test_config.command_output) - call_subprocess(test_config.CT_TOOL + " --autocomplete=foo#bar##--enable-c") + call_subprocess(test_config.CT_TOOL + " --autocomplete=foo#bar##--enable-ct") reference = '--enable-ctad\n' res = res and (reference == test_config.command_output) + call_subprocess(test_config.CT_TOOL + " --autocomplete=foo#bar##--enable-code") + reference = '--enable-codepin\n' + res = res and (reference == test_config.command_output) + call_subprocess(test_config.CT_TOOL + " --autocomplete=foo#bar###--format-range=#a") reference = 'all\n' res = res and (reference == test_config.command_output) @@ -80,17 +84,16 @@ def migrate_test(): '-process-all\n' res = res and (reference == test_config.command_output) - call_subprocess(test_config.CT_TOOL + " --autocomplete=--usm-level=#none,restricted#--use-explicit-namespace=#cl,sycl,") - reference = 'cl,sycl,cl\n' + \ - 'cl,sycl,dpct\n' + \ - 'cl,sycl,none\n' + \ - 'cl,sycl,sycl\n' + \ - 'cl,sycl,sycl-math\n' + call_subprocess(test_config.CT_TOOL + " --autocomplete=--usm-level=#none,restricted#--use-explicit-namespace=#sycl,") + reference = 'sycl,dpct\n' + \ + 'sycl,none\n' + \ + 'sycl,sycl\n' + \ + 'sycl,sycl-math\n' res = res and (reference == test_config.command_output) - call_subprocess(test_config.CT_TOOL + " --autocomplete=--usm-level=#none,restricted#--use-explicit-namespace=#cl,sycl,s") - reference = 'cl,sycl,sycl\n' + \ - 'cl,sycl,sycl-math\n' + call_subprocess(test_config.CT_TOOL + " --autocomplete=--usm-level=#none,restricted#--use-explicit-namespace=#sycl,s") + reference = 'sycl,sycl\n' + \ + 'sycl,sycl-math\n' res = res and (reference == test_config.command_output) call_subprocess(test_config.CT_TOOL + " --autocomplete=") @@ -127,7 +130,6 @@ def migrate_test(): '--in-root\n', '--in-root-exclude\n', '--keep-original-code\n', - '--no-cl-namespace-inline\n', '--no-dpcpp-extensions=\n', '--no-dry-pattern\n', '--no-incremental-migration\n', @@ -168,4 +170,4 @@ def build_test(): return True def run_test(): - return True \ No newline at end of file + return True diff --git a/behavior_tests/src/bt-use-explicit-namespace4/do_test.py b/behavior_tests/src/bt-use-explicit-namespace4/do_test.py deleted file mode 100644 index 73f88d467..000000000 --- a/behavior_tests/src/bt-use-explicit-namespace4/do_test.py +++ /dev/null @@ -1,27 +0,0 @@ -# ====------ do_test.py---------- *- Python -* ----===## -# -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# -# -# ===----------------------------------------------------------------------===# -import subprocess -import platform -import os -import sys - -from test_utils import * - -def setup_test(): - change_dir(test_config.current_test) - return True - -def migrate_test(): - ret = call_subprocess(test_config.CT_TOOL + " --use-explicit-namespace=cl --out-root=./sycl vector_add.cu --cuda-include-path=" + test_config.include_path) - return ret -def build_test(): - return True - -def run_test(): - return True diff --git a/behavior_tests/src/bt-use-explicit-namespace4/vector_add.cu b/behavior_tests/src/bt-use-explicit-namespace4/vector_add.cu deleted file mode 100644 index 3d329b36a..000000000 --- a/behavior_tests/src/bt-use-explicit-namespace4/vector_add.cu +++ /dev/null @@ -1,46 +0,0 @@ -// ====------ vector_add.cu---------- *- CUDA -* ----===//// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// -// ===----------------------------------------------------------------------===// - -#include -#include -#define VECTOR_SIZE 256 - -__global__ void VectorAddKernel(float* A, float* B, float* C) -{ - A[threadIdx.x] = threadIdx.x + 1.0f; - B[threadIdx.x] = threadIdx.x + 1.0f; - C[threadIdx.x] = expf(A[threadIdx.x] + B[threadIdx.x]); -} - -int main() -{ - float *d_A, *d_B, *d_C; - - cudaMalloc(&d_A, VECTOR_SIZE*sizeof(float)); - cudaMalloc(&d_B, VECTOR_SIZE*sizeof(float)); - cudaMalloc(&d_C, VECTOR_SIZE*sizeof(float)); - - VectorAddKernel<<<1, VECTOR_SIZE>>>(d_A, d_B, d_C); - - float Result[VECTOR_SIZE] = { }; - cudaMemcpy(Result, d_C, VECTOR_SIZE*sizeof(float), cudaMemcpyDeviceToHost); - - cudaFree(d_A); - cudaFree(d_B); - cudaFree(d_C); - - for (int i = 0; i < VECTOR_SIZE; i++) { - if (i % 16 == 0) { - printf("\n"); - } - printf("%f ", Result[i]); - } - - return 0; -} diff --git a/behavior_tests/src/bt-yaml-with-different-options1/do_test.py b/behavior_tests/src/bt-yaml-with-different-options1/do_test.py index 6b04f8faf..55b629e57 100644 --- a/behavior_tests/src/bt-yaml-with-different-options1/do_test.py +++ b/behavior_tests/src/bt-yaml-with-different-options1/do_test.py @@ -18,11 +18,11 @@ def setup_test(): return True def migrate_test(): - call_subprocess(test_config.CT_TOOL + " test.cu --in-root . --out-root out --always-use-async-handler --assume-nd-range-dim=1 --comments --enable-ctad --no-dpcpp-extensions=enqueued_barriers --no-dry-pattern --process-all -p . --sycl-named-lambda --use-experimental-features=free-function-queries,nd_range_barrier --use-explicit-namespace=cl,dpct --usm-level=none --cuda-include-path=" + test_config.include_path) + call_subprocess(test_config.CT_TOOL + " test.cu --in-root . --out-root out --always-use-async-handler --assume-nd-range-dim=1 --comments --enable-ctad --no-dpcpp-extensions=enqueued_barriers --no-dry-pattern --process-all -p . 
--sycl-named-lambda --use-experimental-features=free-function-queries,nd_range_barrier --use-explicit-namespace=sycl,dpct --usm-level=none --cuda-include-path=" + test_config.include_path) call_subprocess(test_config.CT_TOOL + " test.cu --out-root out --cuda-include-path=" + test_config.include_path) return is_sub_string("\"--analysis-scope-path=\"", test_config.command_output) and \ is_sub_string("--always-use-async-handler --comments --compilation-database=\"", test_config.command_output) and \ - is_sub_string("--enable-ctad --use-experimental-features=free-function-queries,nd_range_barrier --use-explicit-namespace=cl,dpct --no-dpcpp-extensions=enqueued_barriers --assume-nd-range-dim=1 --no-dry-pattern --process-all --sycl-named-lambda --usm-level=none\".", test_config.command_output) + is_sub_string("--enable-ctad --use-experimental-features=free-function-queries,nd_range_barrier --use-explicit-namespace=sycl,dpct --no-dpcpp-extensions=enqueued_barriers --assume-nd-range-dim=1 --no-dry-pattern --process-all --sycl-named-lambda --usm-level=none\".", test_config.command_output) def build_test(): return True diff --git a/behavior_tests/src/bt-yaml-with-different-options2/do_test.py b/behavior_tests/src/bt-yaml-with-different-options2/do_test.py deleted file mode 100644 index 68cfe7139..000000000 --- a/behavior_tests/src/bt-yaml-with-different-options2/do_test.py +++ /dev/null @@ -1,29 +0,0 @@ -# ====------ do_test.py---------- *- Python -* ----===## -# -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# -# -# ===----------------------------------------------------------------------===# -import subprocess -import platform -import os -import sys - -from test_utils import * - -def setup_test(): - change_dir(test_config.current_test) - return True - -def migrate_test(): - call_subprocess(test_config.CT_TOOL + " test.cu --out-root out --no-cl-namespace-inline --cuda-include-path=" + test_config.include_path) - call_subprocess(test_config.CT_TOOL + " test.cu --out-root out --cuda-include-path=" + test_config.include_path) - return is_sub_string("--no-cl-namespace-inline\".", test_config.command_output) - -def build_test(): - return True - -def run_test(): - return True \ No newline at end of file diff --git a/behavior_tests/src/ngt-invalid-explicit-namespace2/do_test.py b/behavior_tests/src/ngt-invalid-explicit-namespace2/do_test.py index 0522d4e87..3fd65f673 100644 --- a/behavior_tests/src/ngt-invalid-explicit-namespace2/do_test.py +++ b/behavior_tests/src/ngt-invalid-explicit-namespace2/do_test.py @@ -18,7 +18,7 @@ def setup_test(): return True def migrate_test(): - call_subprocess(test_config.CT_TOOL + " --use-explicit-namespace=sycl,cl,dpct vector_add.cu --cuda-include-path=" + test_config.include_path) + call_subprocess(test_config.CT_TOOL + " --use-explicit-namespace=sycl,dpct,none vector_add.cu --cuda-include-path=" + test_config.include_path) return is_sub_string("Error: The input for option --use-explicit-namespace is not valid.", test_config.command_output) diff --git a/behavior_tests/src/ngt-invalid-json7/compile_commands.json b/behavior_tests/src/ngt-invalid-json7/compile_commands.json index f860c2904..46bf0c975 100644 --- a/behavior_tests/src/ngt-invalid-json7/compile_commands.json +++ b/behavior_tests/src/ngt-invalid-json7/compile_commands.json @@ -2,6 +2,9 @@ { "command": "nvcc hello1.c", "directory": "directory_placeholder", - 
"file": "hello1.c" + "file": "hello1.c", + "command": "nvcc hello2.c", + "directory": "directory_placeholder", + "file": "hello2.c" } ] diff --git a/behavior_tests/src/ngt-invalid-json7/do_test.py b/behavior_tests/src/ngt-invalid-json7/do_test.py index 09b55d9e1..e051635d1 100644 --- a/behavior_tests/src/ngt-invalid-json7/do_test.py +++ b/behavior_tests/src/ngt-invalid-json7/do_test.py @@ -21,18 +21,22 @@ def setup_test(): def migrate_test(): data = [] ret = [] + iter = 0 with open("compile_commands.json", 'r') as f: data = f.readlines() for line in data: + if iter == 1: + line = line.replace("directory_placeholder", os.getcwd().replace("\\", "\\\\")) if "directory_placeholder" in line: - ret.append(" \"output\": \"aaaaa\",\n") - line = line.replace("directory_placeholder", os.getcwd().replace("\\", "\\\\")) + iter += 1 + ret.append(line) with open("compile_commands.json", 'w') as f: f.writelines(ret) call_subprocess(test_config.CT_TOOL + ' -p=./ --cuda-include-path=' + test_config.include_path) - return is_sub_string("Unknown key: \"\"output\"\"", test_config.command_output) + return is_sub_string("Processed 1 file(s)", test_config.command_output) + def build_test(): return True @@ -40,3 +44,4 @@ def build_test(): def run_test(): return True + diff --git a/behavior_tests/src/ngt-invalid-json9/hello2.c b/behavior_tests/src/ngt-invalid-json7/hello2.c similarity index 100% rename from behavior_tests/src/ngt-invalid-json9/hello2.c rename to behavior_tests/src/ngt-invalid-json7/hello2.c diff --git a/behavior_tests/src/ngt-invalid-json8/compile_commands.json b/behavior_tests/src/ngt-invalid-json8/compile_commands.json index f860c2904..1c1d9347d 100644 --- a/behavior_tests/src/ngt-invalid-json8/compile_commands.json +++ b/behavior_tests/src/ngt-invalid-json8/compile_commands.json @@ -2,6 +2,6 @@ { "command": "nvcc hello1.c", "directory": "directory_placeholder", - "file": "hello1.c" + "file": "" } ] diff --git a/behavior_tests/src/ngt-invalid-json8/do_test.py b/behavior_tests/src/ngt-invalid-json8/do_test.py index eed7182e4..fdd11e4ef 100644 --- a/behavior_tests/src/ngt-invalid-json8/do_test.py +++ b/behavior_tests/src/ngt-invalid-json8/do_test.py @@ -24,15 +24,13 @@ def migrate_test(): with open("compile_commands.json", 'r') as f: data = f.readlines() for line in data: - if "directory_placeholder" in line: - ret.append(" \"arguments\": [\"nvcc hello_aaa.c\"],\n") line = line.replace("directory_placeholder", os.getcwd().replace("\\", "\\\\")) ret.append(line) with open("compile_commands.json", 'w') as f: f.writelines(ret) call_subprocess(test_config.CT_TOOL + ' -p=./ --cuda-include-path=' + test_config.include_path) - return is_sub_string("Unknown key: \"\"arguments\"\"", test_config.command_output) + return is_sub_string("The file name(s) in the \"command\" and \"file\" fields of the compilation database are inconsistent", test_config.command_output) def build_test(): diff --git a/behavior_tests/src/ngt-invalid-json9/compile_commands.json b/behavior_tests/src/ngt-invalid-json9/compile_commands.json deleted file mode 100644 index 46bf0c975..000000000 --- a/behavior_tests/src/ngt-invalid-json9/compile_commands.json +++ /dev/null @@ -1,10 +0,0 @@ -[ - { - "command": "nvcc hello1.c", - "directory": "directory_placeholder", - "file": "hello1.c", - "command": "nvcc hello2.c", - "directory": "directory_placeholder", - "file": "hello2.c" - } -] diff --git a/behavior_tests/src/ngt-invalid-json10/compile_commands.json b/behavior_tests/src/standard-compilation-db-json1/compile_commands.json similarity 
index 78% rename from behavior_tests/src/ngt-invalid-json10/compile_commands.json rename to behavior_tests/src/standard-compilation-db-json1/compile_commands.json index 1c1d9347d..f860c2904 100644 --- a/behavior_tests/src/ngt-invalid-json10/compile_commands.json +++ b/behavior_tests/src/standard-compilation-db-json1/compile_commands.json @@ -2,6 +2,6 @@ { "command": "nvcc hello1.c", "directory": "directory_placeholder", - "file": "" + "file": "hello1.c" } ] diff --git a/behavior_tests/src/ngt-invalid-json10/do_test.py b/behavior_tests/src/standard-compilation-db-json1/do_test.py similarity index 84% rename from behavior_tests/src/ngt-invalid-json10/do_test.py rename to behavior_tests/src/standard-compilation-db-json1/do_test.py index fdd11e4ef..da807eced 100644 --- a/behavior_tests/src/ngt-invalid-json10/do_test.py +++ b/behavior_tests/src/standard-compilation-db-json1/do_test.py @@ -24,18 +24,18 @@ def migrate_test(): with open("compile_commands.json", 'r') as f: data = f.readlines() for line in data: + if "directory_placeholder" in line: + ret.append(" \"output\": \"aaaaa\",\n") line = line.replace("directory_placeholder", os.getcwd().replace("\\", "\\\\")) ret.append(line) with open("compile_commands.json", 'w') as f: f.writelines(ret) call_subprocess(test_config.CT_TOOL + ' -p=./ --cuda-include-path=' + test_config.include_path) - return is_sub_string("The file name(s) in the \"command\" and \"file\" fields of the compilation database are inconsistent", test_config.command_output) - + return not is_sub_string("Unknown key: \"\"output\"\"", test_config.command_output) def build_test(): return True def run_test(): return True - diff --git a/behavior_tests/src/ngt-invalid-json10/hello1.c b/behavior_tests/src/standard-compilation-db-json1/hello1.c similarity index 100% rename from behavior_tests/src/ngt-invalid-json10/hello1.c rename to behavior_tests/src/standard-compilation-db-json1/hello1.c diff --git a/behavior_tests/src/standard-compilation-db-json2/compile_commands.json b/behavior_tests/src/standard-compilation-db-json2/compile_commands.json new file mode 100644 index 000000000..f860c2904 --- /dev/null +++ b/behavior_tests/src/standard-compilation-db-json2/compile_commands.json @@ -0,0 +1,7 @@ +[ + { + "command": "nvcc hello1.c", + "directory": "directory_placeholder", + "file": "hello1.c" + } +] diff --git a/behavior_tests/src/ngt-invalid-json9/do_test.py b/behavior_tests/src/standard-compilation-db-json2/do_test.py similarity index 79% rename from behavior_tests/src/ngt-invalid-json9/do_test.py rename to behavior_tests/src/standard-compilation-db-json2/do_test.py index e051635d1..22c1c37b9 100644 --- a/behavior_tests/src/ngt-invalid-json9/do_test.py +++ b/behavior_tests/src/standard-compilation-db-json2/do_test.py @@ -21,21 +21,18 @@ def setup_test(): def migrate_test(): data = [] ret = [] - iter = 0 with open("compile_commands.json", 'r') as f: data = f.readlines() for line in data: - if iter == 1: - line = line.replace("directory_placeholder", os.getcwd().replace("\\", "\\\\")) if "directory_placeholder" in line: - iter += 1 - + ret.append(" \"arguments\": [\"nvcc hello1.c\"],\n") + line = line.replace("directory_placeholder", os.getcwd().replace("\\", "\\\\")) ret.append(line) with open("compile_commands.json", 'w') as f: f.writelines(ret) call_subprocess(test_config.CT_TOOL + ' -p=./ --cuda-include-path=' + test_config.include_path) - return is_sub_string("Processed 1 file(s)", test_config.command_output) + return not is_sub_string("Unknown key: \"\"arguments\"\"", 
test_config.command_output) def build_test(): @@ -44,4 +41,3 @@ def build_test(): def run_test(): return True - diff --git a/behavior_tests/src/ngt-invalid-json9/hello1.c b/behavior_tests/src/standard-compilation-db-json2/hello1.c similarity index 100% rename from behavior_tests/src/ngt-invalid-json9/hello1.c rename to behavior_tests/src/standard-compilation-db-json2/hello1.c diff --git a/features/feature_case/asm/asm_arith.cu b/features/feature_case/asm/asm_arith.cu index 023afbc73..90280136d 100644 --- a/features/feature_case/asm/asm_arith.cu +++ b/features/feature_case/asm/asm_arith.cu @@ -16,6 +16,7 @@ #include #include #include +#include #define CHECK(ID, S, CMP) \ { \ @@ -370,6 +371,141 @@ __device__ int shr() { return 0; } +template +__device__ T deg2rad(T val) { + constexpr auto PI = 3.14159265358979323846f; + return val * PI / 180.0f; +} + +#define FLOAT_CMP(X, Y) ((X - Y) < 1e-4) +#define POWF2(X) (pow(2.0f, X)) + +__device__ int asm_copysign() { + float f32 = 0.0f; + double f64 = 0.0; + CHECK(1, asm("copysign.f32 %0, %1, %2;" : "=f"(f32) : "f"(-10.0f), "f"(100.0f)), FLOAT_CMP(f32, -100.0f)); + CHECK(2, asm("copysign.f64 %0, %1, %2;" : "=d"(f64) : "d"(-10.0), "d"(100.0)), FLOAT_CMP(f64, -100.0)); + return 0; +} + +__device__ int asm_cos() { + float f32 = 0.0f; + CHECK(1, asm("cos.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(90.0f))), FLOAT_CMP(f32, 0.0f)); + CHECK(2, asm("cos.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(0.0f))), FLOAT_CMP(f32, 1.0f)); + CHECK(3, asm("cos.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(60.0f))), FLOAT_CMP(f32, 0.5f)); + CHECK(4, asm("cos.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(180.0f))), FLOAT_CMP(f32, -1.0f)); + CHECK(5, asm("cos.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(deg2rad(90.0f))), FLOAT_CMP(f32, 0.0f)); + CHECK(6, asm("cos.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(deg2rad(0.0f))), FLOAT_CMP(f32, 1.0f)); + CHECK(7, asm("cos.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(deg2rad(60.0f))), FLOAT_CMP(f32, 0.5f)); + CHECK(8, asm("cos.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(deg2rad(180.0f))), FLOAT_CMP(f32, -1.0f)); + return 0; +} + +__device__ int asm_sin() { + float f32 = 0.0f; + CHECK(1, asm("sin.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(90.0f))), FLOAT_CMP(f32, 1.0f)); + CHECK(2, asm("sin.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(0.0f))), FLOAT_CMP(f32, 0.0f)); + CHECK(3, asm("sin.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(30.0f))), FLOAT_CMP(f32, 0.5f)); + CHECK(4, asm("sin.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(180.0f))), FLOAT_CMP(f32, 0.0f)); + CHECK(5, asm("sin.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(deg2rad(90.0f))), FLOAT_CMP(f32, 1.0f)); + CHECK(6, asm("sin.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(deg2rad(0.0f))), FLOAT_CMP(f32, 0.0f)); + CHECK(7, asm("sin.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(deg2rad(30.0f))), FLOAT_CMP(f32, 0.5f)); + CHECK(8, asm("sin.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(deg2rad(180.0f))), FLOAT_CMP(f32, 0.0f)); + return 0; +} + +__device__ int asm_tanh() { + float f32 = 0.0f; + CHECK(1, asm("tanh.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(45.0f))), FLOAT_CMP(f32, tanh(deg2rad(45.0f)))); + CHECK(2, asm("tanh.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(0.0f))), FLOAT_CMP(f32, tanh(deg2rad(0.0f)))); + CHECK(3, asm("tanh.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(180.0f))), FLOAT_CMP(f32, tanh(deg2rad(180.0f)))); + CHECK(4, asm("tanh.approx.f32 %0, %1;" : "=f"(f32) : "f"(deg2rad(90.0f))), FLOAT_CMP(f32, tanh(deg2rad(90.0f)))); + return 0; +} + +__device__ int asm_ex2() { + 
float f32 = 0.0f; + CHECK(1, asm("ex2.approx.f32 %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, POWF2(2.1f))); + CHECK(2, asm("ex2.approx.f32 %0, %1;" : "=f"(f32) : "f"(3.4f)), FLOAT_CMP(f32, POWF2(3.4f))); + CHECK(3, asm("ex2.approx.f32 %0, %1;" : "=f"(f32) : "f"(9.7f)), FLOAT_CMP(f32, POWF2(9.7f))); + CHECK(4, asm("ex2.approx.f32 %0, %1;" : "=f"(f32) : "f"(6.4f)), FLOAT_CMP(f32, POWF2(6.4f))); + CHECK(5, asm("ex2.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, POWF2(2.1f))); + CHECK(6, asm("ex2.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(3.4f)), FLOAT_CMP(f32, POWF2(3.4f))); + CHECK(7, asm("ex2.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(9.7f)), FLOAT_CMP(f32, POWF2(9.7f))); + CHECK(8, asm("ex2.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(6.4f)), FLOAT_CMP(f32, POWF2(6.4f))); + return 0; +} + +__device__ int asm_lg2() { + float f32 = 0.0f; + CHECK(1, asm("lg2.approx.f32 %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, log2(2.1f))); + CHECK(2, asm("lg2.approx.f32 %0, %1;" : "=f"(f32) : "f"(3.4f)), FLOAT_CMP(f32, log2(3.4f))); + CHECK(3, asm("lg2.approx.f32 %0, %1;" : "=f"(f32) : "f"(9.7f)), FLOAT_CMP(f32, log2(9.7f))); + CHECK(4, asm("lg2.approx.f32 %0, %1;" : "=f"(f32) : "f"(6.4f)), FLOAT_CMP(f32, log2(6.4f))); + CHECK(5, asm("lg2.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, log2(2.1f))); + CHECK(6, asm("lg2.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(3.4f)), FLOAT_CMP(f32, log2(3.4f))); + CHECK(7, asm("lg2.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(9.7f)), FLOAT_CMP(f32, log2(9.7f))); + CHECK(8, asm("lg2.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(6.4f)), FLOAT_CMP(f32, log2(6.4f))); + return 0; +} + +__device__ int sad() { + int16_t s16; + uint16_t u16; + int32_t s32; + uint32_t u32; + int64_t s64; + uint64_t u64; + CHECK(1, asm("sad.s16 %0, %1, %2, %3;" : "=h"(s16) : "h"((int16_t)-1), "h"((int16_t)3), "h"((int16_t)5)), s16 == 9); + CHECK(2, asm("sad.u16 %0, %1, %2, %3;" : "=h"(u16) : "h"((int16_t)1), "h"((int16_t)3), "h"((int16_t)5)), u16 == 7); + CHECK(3, asm("sad.s32 %0, %1, %2, %3;" : "=r"(s32) : "r"(-1), "r"(3), "r"(5)), s32 == 9); + CHECK(4, asm("sad.u32 %0, %1, %2, %3;" : "=r"(u32) : "r"(1), "r"(3), "r"(5)), u32 == 7); + CHECK(5, asm("sad.s64 %0, %1, %2, %3;" : "=l"(s64) : "l"(-1ll), "l"(3ll), "l"(5ll)), s64 == 9); + CHECK(6, asm("sad.u64 %0, %1, %2, %3;" : "=l"(u64) : "l"(1ll), "l"(3ll), "l"(5ll)), u64 == 7); + return 0; +} + +__device__ int asm_rsqrt() { + float f32; + double f64; + CHECK(1, asm("rsqrt.approx.f32 %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, rsqrt(2.1f))); + CHECK(2, asm("rsqrt.approx.f64 %0, %1;" : "=d"(f64) : "d"(2.1)), FLOAT_CMP(f64, rsqrt(2.1))); + return 0; +} + +__device__ int asm_sqrt() { + float f32; + double f64; + CHECK(1, asm("sqrt.approx.f32 %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, sqrt(2.1f))); + CHECK(2, asm("sqrt.approx.f32.ftz %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, sqrt(2.1f))); + CHECK(3, asm("sqrt.rn.f32 %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, sqrt(2.1f))); + CHECK(4, asm("sqrt.rz.f32 %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, sqrt(2.1f))); + CHECK(5, asm("sqrt.rm.f32 %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, sqrt(2.1f))); + CHECK(6, asm("sqrt.rp.f32 %0, %1;" : "=f"(f32) : "f"(2.1f)), FLOAT_CMP(f32, sqrt(2.1f))); + CHECK(7, asm("sqrt.rn.f64 %0, %1;" : "=d"(f64) : "d"(2.1)), FLOAT_CMP(f64, sqrt(2.1))); + CHECK(8, asm("sqrt.rz.f64 %0, %1;" : "=d"(f64) : "d"(2.1)), FLOAT_CMP(f64, sqrt(2.1))); + CHECK(9, asm("sqrt.rm.f64 %0, %1;" : "=d"(f64) : "d"(2.1)), FLOAT_CMP(f64, 
sqrt(2.1))); + CHECK(10, asm("sqrt.rp.f64 %0, %1;" : "=d"(f64) : "d"(2.1)), FLOAT_CMP(f64, sqrt(2.1))); + return 0; +} + +__device__ int testp() { + int pred = 0; + { asm(".reg .pred p1; testp.finite.f32 p1, %1; @p1 mov.s32 %0, 1;" : "=r"(pred) : "f"(0.1f)); if (!pred) { return 1; } }; + { asm(".reg .pred p2; testp.infinite.f32 p2, %1; @p2 mov.s32 %0, 1;" : "=r"(pred) : "f"(std::numeric_limits::infinity())); if (!pred) { return 2; } }; + { asm(".reg .pred p3; testp.number.f32 p3, %1; @p3 mov.s32 %0, 1;" : "=r"(pred) : "f"(9.7f)); if (!pred) { return 3; } }; + { asm(".reg .pred p4; testp.notanumber.f32 p4, %1; @p4 mov.s32 %0, 1;" : "=r"(pred) : "f"(NAN)); if (!pred) { return 4; } }; + { asm(".reg .pred p5; testp.normal.f32 p5, %1; @p5 mov.s32 %0, 1;" : "=r"(pred) : "f"(9.5f)); if (!pred) { return 5; } }; + { asm(".reg .pred p6; testp.subnormal.f32 p6, %1; @p6 mov.s32 %0, 1;" : "=r"(pred) : "f"(0.1e-300f)); if (!pred) { return 6; } }; + { asm(".reg .pred p7; testp.finite.f64 p7, %1; @p7 mov.s32 %0, 1;" : "=r"(pred) : "d"(0.1)); if (!pred) { return 1; } }; + { asm(".reg .pred p8; testp.infinite.f64 p8, %1; @p8 mov.s32 %0, 1;" : "=r"(pred) : "d"(std::numeric_limits::infinity())); if (!pred) { return 2; } }; + { asm(".reg .pred p9; testp.number.f64 p9, %1; @p9 mov.s32 %0, 1;" : "=r"(pred) : "d"(9.7)); if (!pred) { return 3; } }; + { asm(".reg .pred p10; testp.notanumber.f64 p10, %1; @p10 mov.s32 %0, 1;" : "=r"(pred) : "d"(double(NAN))); if (!pred) { return 4; } }; + { asm(".reg .pred p11; testp.normal.f64 p11, %1; @p11 mov.s32 %0, 1;" : "=r"(pred) : "d"(9.5)); if (!pred) { return 5; } }; + { asm(".reg .pred p12; testp.subnormal.f64 p12, %1; @p12 mov.s32 %0, 1;" : "=r"(pred) : "d"(0.1e-400)); if (!pred) { return 6; } }; + return 0; +} + __device__ int dp2a() { int32_t i32; uint32_t u32; @@ -431,6 +567,16 @@ __global__ void test(int *ec) { TEST(cnot); TEST(shl); TEST(shr); + TEST(asm_copysign); + TEST(asm_cos); + TEST(asm_sin); + TEST(asm_tanh); + TEST(asm_ex2); + TEST(asm_lg2); + TEST(sad); + TEST(asm_rsqrt); + TEST(asm_sqrt); + TEST(testp); TEST(brev); TEST(dp2a); TEST(dp4a); diff --git a/features/feature_case/asm/asm_v2inst.cu b/features/feature_case/asm/asm_v2inst.cu index 7d8dfa452..4abea4564 100644 --- a/features/feature_case/asm/asm_v2inst.cu +++ b/features/feature_case/asm/asm_v2inst.cu @@ -321,6 +321,61 @@ void testVmax2UUS( } } +__global__ void vavrg2UUS(int *Result, int a, int b, int c) { + asm("vavrg2.u32.u32.s32 %0, %1, %2, %3;" + : "=r"(*Result) + : "r"(a), "r"(b), "r"(c)); +} + +__global__ void vavrg2UUSSat(int *Result, int a, int b, int c) { + asm("vavrg2.u32.u32.s32.sat %0, %1, %2, %3;" + : "=r"(*Result) + : "r"(a), "r"(b), "r"(c)); +} + +__global__ void vavrg2UUSAdd(int *Result, int a, int b, int c) { + asm("vavrg2.u32.u32.s32.add %0, %1, %2, %3;" + : "=r"(*Result) + : "r"(a), "r"(b), "r"(c)); +} + +void testvavrg2UUS( + const vector, tuple>> &TestCases) { + int *Result; + cudaMallocManaged(&Result, sizeof(*Result)); + for (const auto &TestCase : TestCases) { + string newCase = "{{" + to_string(get<0>(TestCase.first)) + ", " + + to_string(get<1>(TestCase.first)) + ", " + + to_string(get<2>(TestCase.first)) + "}, {"; + vavrg2UUS<<<1, 1>>>(Result, get<0>(TestCase.first), get<1>(TestCase.first), + get<2>(TestCase.first)); + cudaDeviceSynchronize(); + newCase += to_string(*Result) + ", "; + checkResult("vavrg2.u32.u32.s32", + {get<0>(TestCase.first), get<1>(TestCase.first), + get<2>(TestCase.first)}, + get<0>(TestCase.second), *Result); + vavrg2UUSSat<<<1, 1>>>(Result, 
get<0>(TestCase.first), + get<1>(TestCase.first), get<2>(TestCase.first)); + cudaDeviceSynchronize(); + newCase += to_string(*Result) + ", "; + checkResult("vavrg2.u32.u32.s32.sat", + {get<0>(TestCase.first), get<1>(TestCase.first), + get<2>(TestCase.first)}, + get<1>(TestCase.second), *Result); + vavrg2UUSAdd<<<1, 1>>>(Result, get<0>(TestCase.first), + get<1>(TestCase.first), get<2>(TestCase.first)); + cudaDeviceSynchronize(); + newCase += to_string(*Result) + "}},"; + if (PRINT_CASE) + cout << newCase << endl; + checkResult("vavrg2.u32.u32.s32.add", + {get<0>(TestCase.first), get<1>(TestCase.first), + get<2>(TestCase.first)}, + get<2>(TestCase.second), *Result); + } +} + int main() { srand(unsigned(time(nullptr))); int a = rand(); @@ -401,6 +456,21 @@ int main() { {{2123447767, 63088206, 406272673}, {2123447767, 2123447767, 406320905}}, {{1127203977, 209928516, 352777355}, {1127203977, 1127203977, 352844867}}, }); + testvavrg2UUS({ + // {{a, b, c}, {0, 0, 0}}, + {{3, 4, 1}, {4, 4, 5}}, + {{30000000, 400000000, 10}, {214967232, 214967232, 12442}}, + {{INT16_MAX, 1, 100}, {16384, 16384, 16484}}, + {{UINT16_MAX, 1, 1000}, {32768, 32768, 33768}}, + {{UINT16_MAX, UINT16_MAX, 10000}, {32767, 32767, 42767}}, + {{INT16_MAX, UINT16_MAX, 100000}, {16383, 16383, 116383}}, + {{UINT32_MAX, UINT16_MAX, 1000000}, {-2147450881, -2147450881, 1065535}}, + {{INT16_MAX, UINT32_MAX, 10000000}, {-49153, 16383, 10016382}}, + {{INT16_MIN, INT32_MIN, 100000000}, {1073758208, 1073758208, 100032768}}, + {{454422046, 990649001, 541577428}, {722568292, 722568292, 541622345}}, + {{2123447767, 63088206, 406272673}, {1093333522, 1093271552, 406285789}}, + {{1127203977, 209928516, 352777355}, {668566247, 668566247, 352821067}}, + }); cout << "passed " << passed << "/" << passed + failed << " cases!" << endl; if (failed) { cout << "failed!" 
<< endl; diff --git a/features/feature_case/asm/asm_v4inst.cu b/features/feature_case/asm/asm_v4inst.cu index 1bc1b2fb4..f30f4ab48 100644 --- a/features/feature_case/asm/asm_v4inst.cu +++ b/features/feature_case/asm/asm_v4inst.cu @@ -321,6 +321,61 @@ void testVmax4UUS( } } +__global__ void vavrg4UUS(int *Result, int a, int b, int c) { + asm("vavrg4.u32.u32.s32 %0, %1, %2, %3;" + : "=r"(*Result) + : "r"(a), "r"(b), "r"(c)); +} + +__global__ void vavrg4UUSSat(int *Result, int a, int b, int c) { + asm("vavrg4.u32.u32.s32.sat %0, %1, %2, %3;" + : "=r"(*Result) + : "r"(a), "r"(b), "r"(c)); +} + +__global__ void vavrg4UUSAdd(int *Result, int a, int b, int c) { + asm("vavrg4.u32.u32.s32.add %0, %1, %2, %3;" + : "=r"(*Result) + : "r"(a), "r"(b), "r"(c)); +} + +void testVavrg4UUS( + const vector, tuple>> &TestCases) { + int *Result; + cudaMallocManaged(&Result, sizeof(*Result)); + for (const auto &TestCase : TestCases) { + string newCase = "{{" + to_string(get<0>(TestCase.first)) + ", " + + to_string(get<1>(TestCase.first)) + ", " + + to_string(get<2>(TestCase.first)) + "}, {"; + vavrg4UUS<<<1, 1>>>(Result, get<0>(TestCase.first), get<1>(TestCase.first), + get<2>(TestCase.first)); + cudaDeviceSynchronize(); + newCase += to_string(*Result) + ", "; + checkResult("vavrg4.u32.u32.s32", + {get<0>(TestCase.first), get<1>(TestCase.first), + get<2>(TestCase.first)}, + get<0>(TestCase.second), *Result); + vavrg4UUSSat<<<1, 1>>>(Result, get<0>(TestCase.first), + get<1>(TestCase.first), get<2>(TestCase.first)); + cudaDeviceSynchronize(); + newCase += to_string(*Result) + ", "; + checkResult("vavrg4.u32.u32.s32.sat", + {get<0>(TestCase.first), get<1>(TestCase.first), + get<2>(TestCase.first)}, + get<1>(TestCase.second), *Result); + vavrg4UUSAdd<<<1, 1>>>(Result, get<0>(TestCase.first), + get<1>(TestCase.first), get<2>(TestCase.first)); + cudaDeviceSynchronize(); + newCase += to_string(*Result) + "}},"; + if (PRINT_CASE) + cout << newCase << endl; + checkResult("vavrg4.u32.u32.s32.add", + {get<0>(TestCase.first), get<1>(TestCase.first), + get<2>(TestCase.first)}, + get<2>(TestCase.second), *Result); + } +} + int main() { srand(unsigned(time(nullptr))); int a = rand(); @@ -401,6 +456,21 @@ int main() { {{2123447767, 63088206, 406272673}, {2123447767, 2123447767, 406273220}}, {{1127203977, 209928516, 352777355}, {1127203977, 1127203977, 352777802}}, }); + testVavrg4UUS({ + // {{a, b, c}, {0, 0, 0}}, + {{3, 4, 1}, {4, 4, 5}}, + {{30000000, 400000000, 10}, {206578752, 206578752, 202}}, + {{INT16_MAX, 1, 100}, {16512, 16512, 292}}, + {{UINT16_MAX, 1, 1000}, {32896, 32896, 1256}}, + {{UINT16_MAX, UINT16_MAX, 10000}, {32639, 32639, 10254}}, + {{INT16_MAX, UINT16_MAX, 100000}, {16255, 16255, 100190}}, + {{UINT32_MAX, UINT16_MAX, 1000000}, {-2139062401, -2139062401, 1000510}}, + {{INT16_MAX, UINT32_MAX, 10000000}, {-49281, 16255, 10000188}}, + {{INT16_MIN, INT32_MIN, 100000000}, {1082146816, 1082146816, 100000256}}, + {{454422046, 990649001, 541577428}, {722568419, 722568192, 541577591}}, + {{2123447767, 63088206, 406272673}, {1093333395, 1093271699, 406272912}}, + {{1127203977, 209928516, 352777355}, {685343591, 671122279, 352777590}}, + }); cout << "passed " << passed << "/" << passed + failed << " cases!" << endl; if (failed) { cout << "failed!" 
<< endl; diff --git a/features/feature_case/asm/asm_vinst.cu b/features/feature_case/asm/asm_vinst.cu index fa525c789..3f2fa2607 100644 --- a/features/feature_case/asm/asm_vinst.cu +++ b/features/feature_case/asm/asm_vinst.cu @@ -7,9 +7,8 @@ // // ===--------------------------------------------------------------------===// -#include -#include #include +#include #define CHECK(ID, S, CMP) \ { \ @@ -25,8 +24,9 @@ __device__ int vadd() { unsigned u; CHECK(1, asm("vadd.s32.u32.s32 %0, %1, %2;" : "=r"(i) : "r"(3), "r"(4)), i == 7); CHECK(2, asm("vadd.u32.u32.s32 %0, %1, %2;" : "=r"(u) : "r"(b), "r"(c)), u == 9); - CHECK(3, asm("vadd.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(b), "r"(std::numeric_limits::max())), i == std::numeric_limits::max()); - CHECK(4, asm("vadd.u32.u32.s32.sat %0, %1, %2;" : "=r"(u) : "r"(std::numeric_limits::max()), "r"(std::numeric_limits::max())), u == std::numeric_limits::max()); + CHECK(3, asm("vadd.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(b), "r"(INT_MAX)), i == INT_MAX); + // TODO: Need to keep the same behavior with asm. + // CHECK(4, asm("vadd.u32.u32.s32.sat %0, %1, %2;" : "=r"(u) : "r"(UINT_MAX), "r"(INT_MAX)), u == 0x7FFFFFFEull); CHECK(5, asm("vadd.s32.u32.s32.sat.add %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(-20), "r"(d)), i == -10); CHECK(6, asm("vadd.s32.u32.s32.sat.min %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(c), "r"(-20)), -20); CHECK(7, asm("vadd.s32.u32.s32.sat.max %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(-33), "r"(9)), i == 9); @@ -39,8 +39,8 @@ __device__ int vsub() { unsigned u; CHECK(1, asm("vsub.s32.u32.s32 %0, %1, %2;" : "=r"(i) : "r"(3), "r"(4)), i == -1); CHECK(2, asm("vsub.u32.u32.s32 %0, %1, %2;" : "=r"(u) : "r"(c), "r"(b)), u == 1); - CHECK(3, asm("vsub.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(10), "r"(std::numeric_limits::min())), i == std::numeric_limits::max()); - CHECK(4, asm("vsub.u32.u32.s32.sat %0, %1, %2;" : "=r"(u) : "r"(std::numeric_limits::min()), "r"(1)), u == std::numeric_limits::min()); + CHECK(3, asm("vsub.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(10), "r"(INT_MIN)), i == INT_MAX); + CHECK(4, asm("vsub.u32.u32.s32.sat %0, %1, %2;" : "=r"(u) : "r"(0), "r"(1)), u == 0); CHECK(5, asm("vsub.s32.u32.s32.sat.add %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(-20), "r"(d)), i == 30); CHECK(6, asm("vsub.s32.u32.s32.sat.min %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(c), "r"(-20)), -20); CHECK(7, asm("vsub.s32.u32.s32.sat.max %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(-33), "r"(9)), i == 37); @@ -53,8 +53,8 @@ __device__ int vabsdiff() { unsigned u; CHECK(1, asm("vabsdiff.s32.u32.s32 %0, %1, %2;" : "=r"(i) : "r"(3), "r"(4)), i == 1); CHECK(2, asm("vabsdiff.u32.u32.s32 %0, %1, %2;" : "=r"(u) : "r"(c), "r"(b)), u == 1); - CHECK(3, asm("vabsdiff.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(10), "r"(std::numeric_limits::min())), i == std::numeric_limits::max()); - CHECK(4, asm("vabsdiff.u32.u32.s32.sat %0, %1, %2;" : "=r"(u) : "r"(std::numeric_limits::min()), "r"(1)), u == 1); + CHECK(3, asm("vabsdiff.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(10), "r"(INT_MIN)), i == INT_MAX); + CHECK(4, asm("vabsdiff.u32.u32.s32.sat %0, %1, %2;" : "=r"(u) : "r"(0), "r"(1)), u == 1); CHECK(5, asm("vabsdiff.s32.u32.s32.sat.add %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(-20), "r"(d)), i == 30); CHECK(6, asm("vabsdiff.s32.u32.s32.sat.min %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(c), "r"(-20)), -20); CHECK(7, asm("vabsdiff.s32.u32.s32.sat.max %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(-33), "r"(9)), i == 37); @@ -67,7 +67,7 @@ __device__ int vmin() { unsigned u; 
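
Not part of the patch: a rough host-side model of the .sat behaviour the rewritten CHECK values above rely on, namely forming the result in 64 bits and clamping it to the destination's s32 or u32 range. This is a sketch of the semantics only, not dpct's or the hardware's implementation.

#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdint>

// Clamp a 64-bit intermediate to the signed or unsigned 32-bit destination range.
static int32_t sat_s32(int64_t v) {
    return (int32_t)std::min<int64_t>(std::max<int64_t>(v, INT_MIN), INT_MAX);
}
static uint32_t sat_u32(int64_t v) {
    return (uint32_t)std::min<int64_t>(std::max<int64_t>(v, 0), UINT_MAX);
}

int main() {
    assert(sat_s32((int64_t)10 - INT_MIN) == INT_MAX); // cf. vsub.s32.u32.s32.sat
    assert(sat_u32((int64_t)0 - 1) == 0);              // cf. vsub.u32.u32.s32.sat
    assert(sat_s32((int64_t)4 + INT_MAX) == INT_MAX);  // cf. vadd.s32.u32.s32.sat
    return 0;
}
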
CHECK(1, asm("vmin.s32.u32.s32 %0, %1, %2;" : "=r"(i) : "r"(3), "r"(4)), i == 3); CHECK(2, asm("vmin.u32.u32.s32 %0, %1, %2;" : "=r"(u) : "r"(c), "r"(b)), u == 4); - CHECK(3, asm("vmin.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(std::numeric_limits::max()), "r"(1)), i == 1); + CHECK(3, asm("vmin.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(UINT_MAX), "r"(1)), i == 1); CHECK(4, asm("vmin.u32.u32.s32.sat %0, %1, %2;" : "=r"(u) : "r"(10), "r"(-1)), u == 0); CHECK(5, asm("vmin.s32.u32.s32.sat.add %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(-20), "r"(d)), i == -14); CHECK(6, asm("vmin.s32.u32.s32.sat.min %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(c), "r"(-20)), -20); @@ -81,8 +81,8 @@ __device__ int vmax() { unsigned u; CHECK(1, asm("vmax.s32.u32.s32 %0, %1, %2;" : "=r"(i) : "r"(3), "r"(4)), i == 4); CHECK(2, asm("vmax.u32.u32.s32 %0, %1, %2;" : "=r"(u) : "r"(c), "r"(b)), u == 5); - CHECK(3, asm("vmax.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(std::numeric_limits::max()), "r"(1)), i == std::numeric_limits::max()); - CHECK(4, asm("vmax.u32.u32.s32.sat %0, %1, %2;" : "=r"(u) : "r"(std::numeric_limits::max()), "r"(1)), u == std::numeric_limits::max()); + CHECK(3, asm("vmax.s32.u32.s32.sat %0, %1, %2;" : "=r"(i) : "r"(UINT_MAX), "r"(1)), i == INT_MAX); + CHECK(4, asm("vmax.u32.u32.s32.sat %0, %1, %2;" : "=r"(u) : "r"(UINT_MAX), "r"(1)), u == UINT_MAX); CHECK(5, asm("vmax.s32.u32.s32.sat.add %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(-20), "r"(d)), i == 10); CHECK(6, asm("vmax.s32.u32.s32.sat.min %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(c), "r"(-20)), -20); CHECK(7, asm("vmax.s32.u32.s32.sat.max %0, %1, %2, %3;" : "=r"(i) : "r"(b), "r"(-33), "r"(9)), i == 9); @@ -90,6 +90,34 @@ __device__ int vmax() { return 0; } +__device__ int vshl() { + int i; + unsigned u; + CHECK(1, asm("vshl.s32.s32.u32.clamp %0, %1, %2;" : "=r"(i) : "r"(1), "r"(4)), i == 16); + CHECK(2, asm("vshl.s32.u32.u32.clamp %0, %1, %2;" : "=r"(i) : "r"(1), "r"(33)), i == 0); + CHECK(3, asm("vshl.u32.s32.u32.clamp %0, %1, %2;" : "=r"(u) : "r"(1), "r"(32)), u == 0); + CHECK(4, asm("vshl.u32.u32.u32.clamp %0, %1, %2;" : "=r"(u) : "r"(1), "r"(2)), u == 4); + CHECK(5, asm("vshl.s32.s32.u32.wrap %0, %1, %2;" : "=r"(i) : "r"(1), "r"(4)), i == 16); + CHECK(6, asm("vshl.s32.u32.u32.wrap %0, %1, %2;" : "=r"(i) : "r"(1), "r"(33)), i == 2); + CHECK(7, asm("vshl.u32.s32.u32.wrap %0, %1, %2;" : "=r"(u) : "r"(1), "r"(32)), u == 1); + CHECK(8, asm("vshl.u32.u32.u32.wrap %0, %1, %2;" : "=r"(u) : "r"(1), "r"(2)), u == 4); + return 0; +} + +__device__ int vshr() { + int i; + unsigned u; + CHECK(1, asm("vshr.s32.s32.u32.clamp %0, %1, %2;" : "=r"(i) : "r"(32), "r"(4)), i == 2); + CHECK(2, asm("vshr.s32.u32.u32.clamp %0, %1, %2;" : "=r"(i) : "r"(1), "r"(33)), i == 0); + CHECK(3, asm("vshr.u32.s32.u32.clamp %0, %1, %2;" : "=r"(u) : "r"(1), "r"(32)), u == 0); + CHECK(4, asm("vshr.u32.u32.u32.clamp %0, %1, %2;" : "=r"(u) : "r"(3), "r"(2)), u == 0); + CHECK(5, asm("vshr.s32.s32.u32.wrap %0, %1, %2;" : "=r"(i) : "r"(32), "r"(4)), i == 2); + CHECK(6, asm("vshr.s32.u32.u32.wrap %0, %1, %2;" : "=r"(i) : "r"(1), "r"(33)), i == 0); + CHECK(7, asm("vshr.u32.s32.u32.wrap %0, %1, %2;" : "=r"(u) : "r"(1), "r"(32)), u == 1); + CHECK(8, asm("vshr.u32.u32.u32.wrap %0, %1, %2;" : "=r"(u) : "r"(32), "r"(2)), u == 8); + return 0; +} + // clang-format on __global__ void test(int *ec) { @@ -128,6 +156,20 @@ __global__ void test(int *ec) { return; } } + { + int res = vshl(); + if (res != 0) { + *ec = res; + return; + } + } + { + int res = vshr(); + if (res != 0) { + *ec = res; + 
return; + } + } } int main() { diff --git a/features/feature_case/cub/cub_intrinsic.cu b/features/feature_case/cub/cub_intrinsic.cu index 5c29c10fa..fbe809e2b 100644 --- a/features/feature_case/cub/cub_intrinsic.cu +++ b/features/feature_case/cub/cub_intrinsic.cu @@ -109,6 +109,58 @@ bool test_ptx_version() { return true; } +__global__ void bfe_kernel(int *res) { + if (cub::BFE((uint8_t)0xF0, 4, 8) != 15) { + *res = 1; + return; + } + if (cub::BFE((uint16_t)0x0FF0u, 4, 12) != 255) { + *res = 2; + return; + } + if (cub::BFE(0x00FFFF00u, 8, 16) != 65535u) { + *res = 3; + return; + } + if (cub::BFE(0x000000FFull, 0, 9) != 255) { + *res = 4; + return; + } + *res = 0; +} + +__global__ void bfi_kernel(int *res) { + unsigned d = 0; + cub::BFI(d, 0x00FF0000u, 0x0000FFFFu, 0, 16); + if (d != 0x00FFFFFFu) { + *res = 1; + return; + } + + cub::BFI(d, 0x00FF0000u, 0x000000FFu, 0, 8); + if (d != 0x00FF00FFu) { + *res = 2; + return; + } + *res = 0; +} + +bool test_bfe() { + int *res; + cudaMallocManaged(&res, sizeof(int)); + bfe_kernel<<<1, 1>>>(res); + cudaDeviceSynchronize(); + return *res == 0; +} + +bool test_bfi() { + int *res; + cudaMallocManaged(&res, sizeof(int)); + bfi_kernel<<<1, 1>>>(res); + cudaDeviceSynchronize(); + return *res == 0; +} + #define TEST(FUNC) \ if (!FUNC()) { \ printf(#FUNC " failed\n"); \ @@ -122,5 +174,7 @@ int main() { TEST(test_device_count); TEST(test_sync_stream); TEST(test_ptx_version); + TEST(test_bfe); + TEST(test_bfi); return 0; } diff --git a/features/feature_case/image/text_experimental_build_only.cu b/features/feature_case/image/text_experimental_build_only.cu index 17fe079a9..2cfc1805f 100644 --- a/features/feature_case/image/text_experimental_build_only.cu +++ b/features/feature_case/image/text_experimental_build_only.cu @@ -7,6 +7,8 @@ // // ===---------------------------------------------------------------------===// +#include + __global__ void CppLanguageExtensions_TextureFunctions(cudaTextureObject_t tex) { int i = 1, t = 1; @@ -31,7 +33,7 @@ void Runtime_MemoryManagement() { cudaMemcpy3DParms pm; int i = 1; cudaArrayGetInfo(&d, &e, &u, a); - // cudaFreeArray(a); // TODO: need support. + cudaFreeArray(a); cudaMalloc3D(&p, e); cudaMalloc3DArray(&a, &d, e, u); cudaMallocArray(&a, &d, s, s, u); @@ -54,7 +56,7 @@ void Runtime_MemoryManagement() { void Runtime_TextureObjectManagement() { int i = 1; cudaChannelFormatKind k = cudaChannelFormatKindSigned; - cudaTextureObject_t o{0}; // TODO: need not "{0}". + cudaTextureObject_t o; cudaResourceDesc r; cudaTextureDesc t; // cudaResourceViewDesc v; // TODO: need support. @@ -99,7 +101,7 @@ void Driver_MemoryManagement() { } void Driver_TextureObjectManagement() { - CUtexObject o{0}; // TODO: need not "{0}". + CUtexObject o; CUDA_RESOURCE_DESC R; CUDA_TEXTURE_DESC T; // CUDA_RESOURCE_VIEW_DESC V; // TODO: need support. @@ -114,7 +116,7 @@ int main() { Runtime_TextureObjectManagement(); Driver_MemoryManagement(); Driver_TextureObjectManagement(); - cudaTextureObject_t tex{0}; // TODO: need not "{0}". 
+ cudaTextureObject_t tex; CppLanguageExtensions_TextureFunctions<<<1, 1>>>(tex); return 0; } diff --git a/features/feature_case/image/text_experimental_build_only_before12.cu b/features/feature_case/image/text_experimental_build_only_before12.cu index 5c6100202..10de810ef 100644 --- a/features/feature_case/image/text_experimental_build_only_before12.cu +++ b/features/feature_case/image/text_experimental_build_only_before12.cu @@ -7,6 +7,8 @@ // // ===---------------------------------------------------------------------===// +#include + static texture r; void Runtime_MemoryManagement() { @@ -26,8 +28,8 @@ void Runtime_TextureReferenceManagement() { void *v; cudaChannelFormatDesc d; cudaArray_t a = nullptr; - // cudaBindTexture(&s, &r, v, &d); // TODO: need support. - // cudaBindTexture2D(&s, &r, v, &d, s, s, s); // TODO: need support. + cudaBindTexture(&s, &r, v, &d); + cudaBindTexture2D(&s, &r, v, &d, s, s, s); cudaBindTextureToArray(&r, a, &d); cudaUnbindTexture(&r); } diff --git a/features/feature_case/image/text_experimental_obj_array.cu b/features/feature_case/image/text_experimental_obj_array.cu index b49740da0..fc8889cd7 100644 --- a/features/feature_case/image/text_experimental_obj_array.cu +++ b/features/feature_case/image/text_experimental_obj_array.cu @@ -96,7 +96,7 @@ getTex(cudaArray_t input, cudaResourceDesc resDesc; memset(&resDesc, 0, sizeof(resDesc)); resDesc.resType = cudaResourceTypeArray; - { resDesc.res.array.array = input; } // TODO: need delete {}. + resDesc.res.array.array = input; cudaTextureDesc texDesc; memset(&texDesc, 0, sizeof(texDesc)); @@ -106,7 +106,7 @@ getTex(cudaArray_t input, texDesc.filterMode = textureFilterMode; texDesc.normalizedCoords = normalizedCoords; - cudaTextureObject_t tex{0}; // TODO: need not "{0}". + cudaTextureObject_t tex; cudaCreateTextureObject(&tex, &resDesc, &texDesc, NULL); return tex; @@ -131,7 +131,7 @@ int main() { kernel2<<<1, 1>>>(char2Output, char2Tex, char2W, char2H); cudaDeviceSynchronize(); cudaDestroyTextureObject(char2Tex); - // cudaFreeArray(char2Input); // TODO: need support. + cudaFreeArray(char2Input); for (int i = 0; i < char2W * char2H; ++i) { if (char2Output[2 * i] != char2Expect[i].x || char2Output[2 * i + 1] != char2Expect[i].y) { @@ -167,7 +167,7 @@ int main() { kernel4<<<1, 1>>>(short4Output, short4Tex, short4W, short4H); cudaDeviceSynchronize(); cudaDestroyTextureObject(short4Tex); - // cudaFreeArray(short4Input); // TODO: need support. + cudaFreeArray(short4Input); for (int i = 0; i < short4W * short4H; ++i) { if (short4Output[4 * i] != short4Expect[i].x || short4Output[4 * i + 1] != short4Expect[i].y || @@ -788,7 +788,7 @@ int main() { } pass = true; } - // cudaFreeArray(float2Input); // TODO: need support. + cudaFreeArray(float2Input); } cout << "passed " << passed << "/" << passed + failed << " cases!" << endl; diff --git a/features/feature_case/image/text_experimental_obj_linear.cu b/features/feature_case/image/text_experimental_obj_linear.cu index faae74543..b31b9aa42 100644 --- a/features/feature_case/image/text_experimental_obj_linear.cu +++ b/features/feature_case/image/text_experimental_obj_linear.cu @@ -74,7 +74,7 @@ cudaTextureObject_t getTex(void *input, cudaChannelFormatDesc desc, cudaTextureDesc texDesc; memset(&texDesc, 0, sizeof(texDesc)); - cudaTextureObject_t tex{0}; // TODO: need not "{0}". 
+ cudaTextureObject_t tex; cudaCreateTextureObject(&tex, &resDesc, &texDesc, NULL); return tex; diff --git a/features/feature_case/image/text_experimental_obj_pitch2d.cu b/features/feature_case/image/text_experimental_obj_pitch2d.cu index b00500374..5a70dd1aa 100644 --- a/features/feature_case/image/text_experimental_obj_pitch2d.cu +++ b/features/feature_case/image/text_experimental_obj_pitch2d.cu @@ -116,7 +116,7 @@ getTex(void *input, size_t w, size_t h, cudaChannelFormatDesc desc, texDesc.filterMode = textureFilterMode; texDesc.normalizedCoords = normalizedCoords; - cudaTextureObject_t tex{0}; // TODO: need not "{0}". + cudaTextureObject_t tex; cudaCreateTextureObject(&tex, &resDesc, &texDesc, NULL); return tex; @@ -455,7 +455,7 @@ int main() { float4x3WrapLinearExpect[i].w - precision || float4x3WrapLinearOutput[4 * i + 3] > float4x3WrapLinearExpect[i].w + precision)) { - pass = false; // TODO: need support. + pass = false; break; } } @@ -538,7 +538,7 @@ int main() { float4x3BorderLinearExpect[i].w - precision || float4x3BorderLinearOutput[4 * i + 3] > float4x3BorderLinearExpect[i].w + precision)) { - pass = false; // TODO: need support. + pass = false; break; } } @@ -613,7 +613,7 @@ int main() { float4x3WrapNormExpect[i].w - precision || float4x3WrapNormOutput[4 * i + 3] > float4x3WrapNormExpect[i].w + precision)) { - pass = false; // TODO: need support. + pass = false; break; } } @@ -677,7 +677,7 @@ int main() { float4x3ClampNormExpect[i].w - precision || float4x3ClampNormOutput[4 * i + 3] > float4x3ClampNormExpect[i].w + precision)) { - pass = false; // TODO: need support. + pass = false; break; } } @@ -743,7 +743,7 @@ int main() { float4x3MirrorNormExpect[i].w - precision || float4x3MirrorNormOutput[4 * i + 3] > float4x3MirrorNormExpect[i].w + precision)) { - pass = false; // TODO: need support. + pass = false; break; } } @@ -810,7 +810,7 @@ int main() { float4x3BorderNormExpect[i].w - precision || float4x3BorderNormOutput[4 * i + 3] > float4x3BorderNormExpect[i].w + precision)) { - pass = false; // TODO: need support. + pass = false; break; } } @@ -893,7 +893,7 @@ int main() { float4x3WrapLinearNormExpect[i].w - precision || float4x3WrapLinearNormOutput[4 * i + 3] > float4x3WrapLinearNormExpect[i].w + precision)) { - pass = false; // TODO: need support. + pass = false; break; } } @@ -984,7 +984,7 @@ int main() { float4x3ClampLinearNormExpect[i].w - precision || float4x3ClampLinearNormOutput[4 * i + 3] > float4x3ClampLinearNormExpect[i].w + precision)) { - pass = false; // TODO: need support. + pass = false; break; } } @@ -1076,7 +1076,7 @@ int main() { float4x3MirrorLinearNormExpect[i].w - precision || float4x3MirrorLinearNormOutput[4 * i + 3] > float4x3MirrorLinearNormExpect[i].w + precision)) { - pass = false; // TODO: need support. + pass = false; break; } } @@ -1169,7 +1169,7 @@ int main() { float4x3BorderLinearNormExpect[i].w - precision || float4x3BorderLinearNormOutput[4 * i + 3] > float4x3BorderLinearNormExpect[i].w + precision)) { - pass = false; // TODO: need support. 
+ pass = false; break; } } diff --git a/behavior_tests/src/bt-yaml-with-different-options2/test.cu b/features/feature_case/misc/assert.cu similarity index 64% rename from behavior_tests/src/bt-yaml-with-different-options2/test.cu rename to features/feature_case/misc/assert.cu index aa80fe47c..bd8718f4e 100644 --- a/behavior_tests/src/bt-yaml-with-different-options2/test.cu +++ b/features/feature_case/misc/assert.cu @@ -1,4 +1,4 @@ -// ====------ test.cu---------- *- CUDA -* ----===//// +// ====------ assert.cu---------- *- CUDA -* ----===//// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,9 +6,12 @@ // // // ===----------------------------------------------------------------------===// -#include + +__global__ void kernel() { + __assertfail("", __FILE__, __LINE__, __func__, sizeof(char)); +} int main() { - float2 f2; + kernel<<<1, 64>>>(); return 0; -} +} \ No newline at end of file diff --git a/features/features.xml b/features/features.xml index a43678f66..49826a9c5 100644 --- a/features/features.xml +++ b/features/features.xml @@ -309,9 +309,9 @@ - - - + + + @@ -324,5 +324,6 @@ + diff --git a/features/test_feature.py b/features/test_feature.py index 5b276310a..c507e3887 100644 --- a/features/test_feature.py +++ b/features/test_feature.py @@ -42,7 +42,6 @@ 'cudnn-GetErrorString', 'cub_device_histgram', 'peer_access', 'cudnn-types', 'cudnn-version', 'cudnn-dropout', 'const_opt', 'constant_attr', 'sync_warp_p2', 'occupancy_calculation', - 'text_experimental_obj_array', 'text_experimental_obj_linear', 'text_experimental_obj_pitch2d', 'text_obj_array', 'text_obj_linear', 'text_obj_pitch2d', 'match', 'thrust-unique_by_key', 'cufft_test', 'cufft-external-workspace', "pointer_attributes", 'math_intel_specific', 'math-drcp', 'thrust-pinned-allocator', 'driverMem', 'cusolver_test1', 'cusolver_test2', 'cusolver_test3', 'cusolver_test4', 'cusolver_test5', 'thrust_op', 'cublas-extension', 'cublas_v1_runable', 'thrust_minmax_element', @@ -195,6 +194,8 @@ def build_test(): ret = compile_and_link(srcs, cmp_options, objects, link_opts) elif re.match('^cufft.*', test_config.current_test) and platform.system() == 'Linux': ret = compile_and_link(srcs, cmp_options, objects, link_opts) + elif test_config.current_test.startswith('text_experimental_obj_') and test_config.device_filter == "cuda:gpu": + ret = compile_and_link(srcs, cmp_options, objects, link_opts) else: ret = compile_files(srcs, cmp_options) return ret diff --git a/features/config/TEMPLATE_image_skip_cpu_gpu.xml b/help_function/config/TEMPLATE_help_function_skip_cpu_gpu.xml similarity index 59% rename from features/config/TEMPLATE_image_skip_cpu_gpu.xml rename to help_function/config/TEMPLATE_help_function_skip_cpu_gpu.xml index fa8b32e6e..8cb2aae03 100644 --- a/features/config/TEMPLATE_image_skip_cpu_gpu.xml +++ b/help_function/config/TEMPLATE_help_function_skip_cpu_gpu.xml @@ -1,13 +1,12 @@ - - test + + test help function - + - diff --git a/help_function/help_function.xml b/help_function/help_function.xml index abf3e4728..cfbd1411f 100644 --- a/help_function/help_function.xml +++ b/help_function/help_function.xml @@ -81,6 +81,7 @@ + diff --git a/help_function/src/bindless_images.cpp b/help_function/src/bindless_images.cpp new file mode 100644 index 000000000..d7a2ecab7 --- /dev/null +++ b/help_function/src/bindless_images.cpp @@ -0,0 +1,70 @@ +// ===------------ bindless_images.cpp ---------- -*- C++ -* --------------===// +// +// Part of the 
LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// +// ===---------------------------------------------------------------------===// + +#include + +dpct::experimental::bindless_image_wrapper w; + +int main() { + auto q = dpct::get_default_queue(); + sycl::ext::oneapi::experimental::image_descriptor d( + {1, 1}, sycl::image_channel_order::rgba, sycl::image_channel_type::fp32); + dpct::experimental::image_mem_ptr mem = + new sycl::ext::oneapi::experimental::image_mem(d, q); + sycl::float4 *data = + (sycl::float4 *)dpct::dpct_malloc(sizeof(sycl::float4), q); + dpct::image_data img(mem); + img.set_data(mem); + size_t pitch; + sycl::float4 *data2d = + (sycl::float4 *)dpct::dpct_malloc(pitch, sizeof(sycl::float4), 1); + dpct::sampling_info samp; + dpct::image_channel chan = dpct::image_channel::create(); + printf("prepare variable pass!\n"); + dpct::experimental::async_dpct_memcpy(mem, 0, 0, data, 1, 1, 1, q); + dpct::experimental::dpct_memcpy(mem, 0, 0, data, 1, 1, 1, q); + dpct::experimental::async_dpct_memcpy(mem, 0, 0, data, 1, q); + dpct::experimental::dpct_memcpy(mem, 0, 0, data, 1, q); + dpct::experimental::async_dpct_memcpy(data, mem, 0, 0, 1, 1, 1, q); + dpct::experimental::dpct_memcpy(data, mem, 0, 0, 1, 1, 1, q); + dpct::experimental::async_dpct_memcpy(data, mem, 0, 0, 1, q); + dpct::experimental::dpct_memcpy(data, mem, 0, 0, 1, q); + dpct::experimental::dpct_memcpy(mem, 0, 0, mem, 0, 0, 1, 1, q); + dpct::experimental::dpct_memcpy(mem, 0, 0, mem, 0, 0, 1, q); + printf("memory copy pass!\n"); + auto h = dpct::experimental::create_bindless_image(img, samp, q); + dpct::experimental::get_data(h); + dpct::experimental::get_sampling_info(h); + dpct::experimental::get_channel(mem); + dpct::experimental::destroy_bindless_image(h, q); + printf("texture object pass!\n"); + w.attach(data, 1, chan, q); + w.attach(data, 1, q); + w.attach(data2d, 1, 1, pitch, chan, q); + w.attach(data2d, 1, 1, pitch, chan, q); + w.attach(mem, chan, q); + w.attach(mem, q); + w.detach(q); + w.set_channel(chan); + w.get_channel(); + w.set_channel_size(4, 1); + w.get_channel_size(); + w.set_channel_data_type(dpct::image_channel_data_type::fp); + w.get_channel_data_type(); + w.set(sycl::addressing_mode::repeat); + w.get_addressing_mode(); + w.set(sycl::coordinate_normalization_mode::unnormalized); + w.is_coordinate_normalized(); + w.set(sycl::filtering_mode::nearest); + w.get_filtering_mode(); + w.get_handle(); + printf("texture referece pass!\n"); + printf("test pass!\n"); + return 0; +} diff --git a/run_test.py b/run_test.py index 009017f33..5a4fc7128 100644 --- a/run_test.py +++ b/run_test.py @@ -485,6 +485,7 @@ def main(): args = parse_input_args() do_sanity_test() set_default_compiler(args.option == 'option_cuda_backend') + set_default_c_compiler(args.option == 'option_cuda_backend') suite_list = get_suite_list() test_config.root_path = os.getcwd() test_config.option_map = get_option_mapping() diff --git a/test_config.py b/test_config.py index 394b19dec..896bf3910 100644 --- a/test_config.py +++ b/test_config.py @@ -9,6 +9,7 @@ # The arguments DPCXX_COM = "" # Default compiler will set by set_default_compiler API call. +C_COM = "" # Default c compiler will set by set_default_c_compiler API call. 
CT_TOOL = "dpct" # The migration tool binary name PYTHON_COM = "python3 " suite_list_file = "test_suite_list.xml" # The configuration file lists all the suite to run and corresponding options. diff --git a/test_suite_list.xml b/test_suite_list.xml index fef1b7a97..157354986 100644 --- a/test_suite_list.xml +++ b/test_suite_list.xml @@ -10,5 +10,6 @@ + diff --git a/test_utils.py b/test_utils.py index e043e1e3e..833e305cd 100644 --- a/test_utils.py +++ b/test_utils.py @@ -63,6 +63,15 @@ def set_default_compiler(use_cuda : bool): else: test_config.DPCXX_COM = "icpx -fsycl" +def set_default_c_compiler(use_cuda : bool): + if (use_cuda): + test_config.C_COM = 'clang' + return + if (platform.system() == 'Windows'): + test_config.C_COM = "icx-cc" + else: + test_config.C_COM = "icx" + def print_debug_log(desc, *args): if (test_config.VERBOSE_LEVEL != 0): print(desc + " : " ) @@ -72,10 +81,13 @@ def print_debug_log(desc, *args): def compile_files(srcs, cmpopts = []): ret = True - base_cmd = test_config.DPCXX_COM + " -c " - if (platform.system() == 'Windows'): - base_cmd += " /EHsc -DNOMINMAX " for src in srcs: + if os.path.splitext(src)[1] == '.c': + base_cmd = test_config.C_COM + " -c " + else: + base_cmd = test_config.DPCXX_COM + " -c " + if (platform.system() == 'Windows'): + base_cmd += " /EHsc -DNOMINMAX " cmd = base_cmd + src + ' ' + ' '.join(cmpopts) ret = call_subprocess(cmd) and ret return ret