diff --git a/package.json b/package.json index b4c3e6c..b18022a 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "lycoris", "private": true, - "version": "0.9.21", + "version": "0.9.22", "type": "module", "license": "MIT", "engines": { @@ -44,6 +44,6 @@ "postcss": "^8.4.41", "tailwindcss": "^3.4.10", "typescript": "^5.5.4", - "vite": "^5.4.2" + "vite": "^5.4.6" } } diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index b6675b9..f05c24c 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -27,6 +27,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.8.11" @@ -170,9 +176,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.83" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25bdb32cbbdce2b519a9cd7df3a678443100e265d5e25ca763b7572a5104f5f3" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" [[package]] name = "arbitrary" @@ -220,6 +226,12 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +[[package]] +name = "as-any" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a30a44e99a1c83ccb2a6298c563c888952a1c9134953db26876528f84c93a" + [[package]] name = "async-broadcast" version = "0.5.1" @@ -469,7 +481,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.2", "object", "rustc-demangle", ] @@ -719,11 +731,12 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.6.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=3c8e120#3c8e120e8a6ae88b41ee251ec6255035864858d6" +version = "0.7.2" +source = "git+https://github.com/EricLBuehler/candle.git?rev=628775#6287750d26e2a9ed6e5f4f4774c51e6af109536c" dependencies = [ "byteorder", "candle-metal-kernels", + "float8", "gemm", "half", "memmap2", @@ -741,8 +754,8 @@ dependencies = [ [[package]] name = "candle-metal-kernels" -version = "0.6.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=3c8e120#3c8e120e8a6ae88b41ee251ec6255035864858d6" +version = "0.7.2" +source = "git+https://github.com/EricLBuehler/candle.git?rev=628775#6287750d26e2a9ed6e5f4f4774c51e6af109536c" dependencies = [ "metal", "once_cell", @@ -752,8 +765,8 @@ dependencies = [ [[package]] name = "candle-nn" -version = "0.6.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=3c8e120#3c8e120e8a6ae88b41ee251ec6255035864858d6" +version = "0.7.2" +source = "git+https://github.com/EricLBuehler/candle.git?rev=628775#6287750d26e2a9ed6e5f4f4774c51e6af109536c" dependencies = [ "candle-core", "candle-metal-kernels", @@ -903,6 +916,7 @@ dependencies = [ "anstyle", "clap_lex", "strsim 0.11.1", + "terminal_size", ] [[package]] @@ -947,9 +961,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.50" +version = "0.1.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" +checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" dependencies = [ "cc", ] @@ -1363,16 +1377,16 @@ dependencies = [ [[package]] name = "ct2rs" -version = "0.7.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a02ec8476ffc2a7330384ae07bc69e05d3a9366167e6479d9b827818e677977e" +checksum = "d755bc515349d0a0a7c1b61bc2439a7afbf01be00a3139af73618045a6ea3614" dependencies = [ "anyhow", "cmake", "cxx", "cxx-build", "sentencepiece", - "tokenizers", + "tokenizers 0.20.1", "walkdir", ] @@ -1388,9 +1402,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.122" +version = "1.0.129" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb497fad022245b29c2a0351df572e2d67c1046bcef2260ebc022aec81efea82" +checksum = "cbdc8cca144dce1c4981b5c9ab748761619979e515c3d53b5df385c677d1d007" dependencies = [ "cc", "cxxbridge-flags", @@ -1400,9 +1414,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.122" +version = "1.0.129" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9327c7f9fbd6329a200a5d4aa6f674c60ab256525ff0084b52a889d4e4c60cee" +checksum = "c5764c3142ab44fcf857101d12c0ddf09c34499900557c764f5ad0597159d1fc" dependencies = [ "cc", "codespan-reporting", @@ -1415,15 +1429,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.122" +version = "1.0.129" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "688c799a4a846f1c0acb9f36bb9c6272d9b3d9457f3633c7753c6057270df13c" +checksum = "d422aff542b4fa28c2ce8e5cc202d42dbf24702345c1fba3087b2d3f8a1b90ff" [[package]] name = "cxxbridge-macro" -version = "1.0.122" +version = "1.0.129" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "928bc249a7e3cd554fd2e8e08a426e9670c50bbfc9a621653cfa9accc9641783" +checksum = "a1719100f31492cd6adeeab9a0f46cdbc846e615fdb66d7b398aa46ec7fdd06f" dependencies = [ "proc-macro2", "quote", @@ -1679,6 +1693,17 @@ dependencies = [ "syn 2.0.61", ] +[[package]] +name = "derive-new" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.61", +] + [[package]] name = "derive_arbitrary" version = "1.3.2" @@ -1904,6 +1929,70 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +[[package]] +name = "encoding" +version = "0.2.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b0d943856b990d12d3b55b359144ff341533e516d94098b1d3fc1ac666d36ec" +dependencies = [ + "encoding-index-japanese", + "encoding-index-korean", + "encoding-index-simpchinese", + "encoding-index-singlebyte", + "encoding-index-tradchinese", +] + +[[package]] +name = "encoding-index-japanese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e8b2ff42e9a05335dbf8b5c6f7567e5591d0d916ccef4e0b1710d32a0d0c91" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-korean" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dc33fb8e6bcba213fe2f14275f0963fd16f0a02c878e3095ecfdf5bee529d81" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-simpchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d87a7194909b9118fc707194baa434a4e3b0fb6a5a757c73c3adb07aa25031f7" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-singlebyte" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3351d5acffb224af9ca265f435b859c7c01537c0849754d3db3fdf2bfe2ae84a" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding-index-tradchinese" +version = "1.20141219.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd0e20d5688ce3cab59eb3ef3a2083a5c77bf496cb798dc6fcdb75f323890c18" +dependencies = [ + "encoding_index_tests", +] + +[[package]] +name = "encoding_index_tests" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a246d82be1c9d791c5dfde9a2bd045fc3cbba3fa2b11ad558f27d01712f00569" + [[package]] name = "encoding_rs" version = "0.8.33" @@ -1913,6 +2002,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "encoding_rs_io" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cc3c5651fb62ab8aa3103998dade57efdd028544bd300516baa31840c252a83" +dependencies = [ + "encoding_rs", +] + [[package]] name = "endian-type" version = "0.1.2" @@ -1952,6 +2050,29 @@ dependencies = [ "syn 2.0.61", ] +[[package]] +name = "env_filter" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -2063,7 +2184,7 @@ dependencies = [ "flume", "half", "lebe", - "miniz_oxide", + "miniz_oxide 0.7.2", "rayon-core", "smallvec", "zune-inflate", @@ -2141,12 +2262,24 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.28" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" +checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" dependencies = [ "crc32fast", - "miniz_oxide", + "miniz_oxide 0.8.0", +] + +[[package]] +name = "float8" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c3475274d374d263c4c40c43ad854c5bdf733c7db775bbd3c1ca2ad7427978" +dependencies = [ + "half", + "num-traits", + "rand 0.8.5", + "rand_distr", ] [[package]] @@ -3075,6 +3208,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "hyper" version = "0.14.28" @@ -3466,6 +3605,15 @@ dependencies = [ "system-deps 5.0.0", ] +[[package]] +name = "jlabel" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "145ee6f495871a0cde6d49ddfa0d103d07430c449d95b6d92fbfb032d622f0b7" +dependencies = [ + "thiserror", +] + [[package]] name = "jni" version = "0.19.0" @@ -3518,6 +3666,124 @@ dependencies = [ "rayon", ] +[[package]] +name = "jpreprocess" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05ad4316553f214144e04abb54a95f0ec55d9b8b5c4ae004f420ead40d07fe4" +dependencies = [ + "jlabel", + "jpreprocess-core", + "jpreprocess-dictionary", + "jpreprocess-dictionary-builder", + "jpreprocess-jpcommon", + "jpreprocess-naist-jdic", + "jpreprocess-njd", + "lindera-core", + "lindera-dictionary", + "lindera-tokenizer", + "phf 0.11.2", +] + +[[package]] +name = "jpreprocess-core" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe30c65ff4c092320f1bba3418ac111443a4827a9155442f6a7d8d0e3707cb51" +dependencies = [ + "aho-corasick", + "bincode", + "lindera-core", + "lindera-tokenizer", + "once_cell", + "regex", + "serde", + "thiserror", +] + +[[package]] +name = "jpreprocess-dictionary" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c156a875d35ef6fedf31cb9d6bb3c562d16faad4506a5be27e2ed44357d755d4" +dependencies = [ + "anyhow", + "bincode", + "byteorder", + "jpreprocess-core", + "lindera-core", + "lindera-ipadic-builder", + "lindera-tokenizer", + "once_cell", + "serde", +] + +[[package]] +name = "jpreprocess-dictionary-builder" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cd89b119949a5071e6f49d805829f3dc17169eb7c6aab809e4f373c70098709" +dependencies = [ + "anyhow", + "bincode", + "byteorder", + "csv", + "glob", + "jpreprocess-core", + "jpreprocess-dictionary", + "lindera-core", + "log", + "rayon", + "serde", + "yada", +] + +[[package]] +name = "jpreprocess-jpcommon" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a95684847dcf9a95a94d74f725ca207e5106f2c0084959d9b328b7f8fcf3184" +dependencies = [ + "jlabel", + "jpreprocess-core", + "jpreprocess-njd", +] + +[[package]] +name = "jpreprocess-naist-jdic" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "472767e7dc48354e3b42759721ffc6b457856a339b10f9cb039749011f570cc2" +dependencies = [ + "encoding", + "flate2", + "jpreprocess-dictionary", + "jpreprocess-dictionary-builder", + "lindera-core", + "tar", + "ureq", +] + +[[package]] +name = "jpreprocess-njd" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25fc5071bd75e17af650bfe4697dfe7f5af8254965ce3242476698cee3c3b7af" +dependencies = [ + "aho-corasick", + "jpreprocess-core", + "jpreprocess-dictionary", + "jpreprocess-window", + "lindera-tokenizer", + "phf 0.11.2", +] + +[[package]] +name = "jpreprocess-window" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c06d7aceb8ce626a3318183096aa6dad82f046b3cec5d43e90066d1b07445a2" + [[package]] name = "js-sys" version = "0.3.69" @@ -3658,6 +3924,149 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "lindera-cc-cedict-builder" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b8f642bc9c9130682569975772a17336c6aab26d11fc0f823f3e663167ace6" +dependencies = [ + "anyhow", + "lindera-core", + "lindera-decompress", + "lindera-dictionary-builder", +] + +[[package]] +name = "lindera-core" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c28191456debc98af6aa5f7db77872471983e9fa2a737b1c232b6ef543aed62" +dependencies = [ + "anyhow", + "bincode", + "byteorder", + "encoding_rs", + "log", + "once_cell", + "serde", + "thiserror", + "yada", +] + +[[package]] +name = "lindera-decompress" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4788a1ead2f63f3fc2888109272921dedd86a87b7d0bf05e9daab46600daac51" +dependencies = [ + "anyhow", + "flate2", + "serde", +] + +[[package]] +name = "lindera-dictionary" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdf5f91725e32b9a21b1656baa7030766c9bafc4de4b4ddeb8ffdde7224dd2f6" +dependencies = [ + "anyhow", + "bincode", + "byteorder", + "lindera-cc-cedict-builder", + "lindera-core", + "lindera-ipadic-builder", + "lindera-ipadic-neologd-builder", + "lindera-ko-dic-builder", + "lindera-unidic-builder", + "serde", + "strum", + "strum_macros", +] + +[[package]] +name = "lindera-dictionary-builder" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e41f00ba7ac541b0ffd8c30e7a73f2dd197546cc5780462ec4f2e4782945a780" +dependencies = [ + "anyhow", + "bincode", + "byteorder", + "csv", + "derive_builder", + "encoding", + "encoding_rs", + "encoding_rs_io", + "glob", + "lindera-core", + "lindera-decompress", + "log", + "yada", +] + +[[package]] +name = "lindera-ipadic-builder" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf5031c52686128db13f774b2c5a8abfd52b4cc1f904041d8411aa19d630ce4d" +dependencies = [ + "anyhow", + "lindera-core", + "lindera-decompress", + "lindera-dictionary-builder", +] + +[[package]] +name = "lindera-ipadic-neologd-builder" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abf36e40ace904741efdd883ed5c4dba6425f65156a0fb5d3f73a386335950dc" +dependencies = [ + "anyhow", + "lindera-core", + "lindera-decompress", + "lindera-dictionary-builder", +] + +[[package]] +name = "lindera-ko-dic-builder" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f2c60425abc1548570c2568858f74a1f042105ecd89faa39c651b4315350fd9" +dependencies = [ + "anyhow", + "lindera-core", + "lindera-decompress", + "lindera-dictionary-builder", +] + +[[package]] +name = "lindera-tokenizer" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "903e558981bcb6f59870aa7d6b4bcb09e8f7db778886a6a70f67fd74c9fa2ca3" +dependencies = [ + "bincode", + "lindera-core", + "lindera-dictionary", + "once_cell", + "serde", + "serde_json", +] + +[[package]] +name = "lindera-unidic-builder" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99e2c50015c242e02c451acb6748667ac6fd1d3d667cd7db48cd89e2f2d2377e" +dependencies = [ + "anyhow", + "lindera-core", + "lindera-decompress", + "lindera-dictionary-builder", +] + [[package]] name = "line-wrap" version = "0.2.0" @@ -3740,7 +4149,7 @@ dependencies = [ [[package]] name = "lycoris" -version = "0.9.21" +version = "0.9.22" dependencies = [ "chrono", "core-graphics 0.24.0", @@ -3759,6 +4168,7 @@ dependencies = [ "reqwest 0.12.5", "rusqlite", "samplerate-rs", + "sbv2_core", "screencapturekit", "serde", "serde_json", @@ -3770,7 +4180,6 @@ dependencies = [ "tokio", "unicode-segmentation", "urlencoding", - "uuid 1.10.0", "vosk", "whisper-rs", "xcap", @@ -3868,6 +4277,16 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -3998,6 +4417,15 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mio" version = "0.8.11" @@ -4012,8 +4440,8 @@ dependencies = [ [[package]] name = "mistralrs" -version = "0.3.0" -source = "git+https://github.com/EricLBuehler/mistral.rs.git?tag=v0.3.0#ae71578be27f4369a4d9a0c7d9b849be14c82162" +version = "0.3.1" +source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=32e8945#32e894510696e9aa3c11db79268ee031a3ecefa6" dependencies = [ "anyhow", "candle-core", @@ -4031,11 +4459,12 @@ dependencies = [ [[package]] name = "mistralrs-core" -version = "0.3.0" -source = "git+https://github.com/EricLBuehler/mistral.rs.git?tag=v0.3.0#ae71578be27f4369a4d9a0c7d9b849be14c82162" +version = "0.3.1" +source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=32e8945#32e894510696e9aa3c11db79268ee031a3ecefa6" dependencies = [ "akin", "anyhow", + "as-any", "async-trait", "base64 0.22.1", "buildstructor", @@ -4047,10 +4476,11 @@ dependencies = [ "chrono", "clap", "csv", - "derive-new", + "derive-new 0.7.0", "derive_more", "dirs", "either", + "float8", "futures", "galil-seiferas", "half", @@ -4073,14 +4503,16 @@ dependencies = [ "regex-automata 0.4.6", "reqwest 0.12.5", "rustc-hash 2.0.0", + "safetensors", "schemars", "serde", "serde_json", + "serde_plain", "serde_yaml", "strum", "sysinfo", "thiserror", - "tokenizers", + "tokenizers 0.19.1", "tokio", "tokio-rayon", "toml 0.8.12", @@ -4094,13 +4526,16 @@ dependencies = [ [[package]] name = "mistralrs-quant" -version = "0.3.0" -source = "git+https://github.com/EricLBuehler/mistral.rs.git?tag=v0.3.0#ae71578be27f4369a4d9a0c7d9b849be14c82162" +version = "0.3.1" +source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=32e8945#32e894510696e9aa3c11db79268ee031a3ecefa6" dependencies = [ + "byteorder", "candle-core", "candle-nn", + "float8", "half", "lazy_static", + "once_cell", "paste", "rayon", "serde", @@ -4109,8 +4544,8 @@ dependencies = [ [[package]] name = "mistralrs-vision" -version = "0.3.0" -source = "git+https://github.com/EricLBuehler/mistral.rs.git?tag=v0.3.0#ae71578be27f4369a4d9a0c7d9b849be14c82162" +version = "0.3.1" +source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=32e8945#32e894510696e9aa3c11db79268ee031a3ecefa6" dependencies = [ "candle-core", "image 0.25.2", @@ -4155,6 +4590,21 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk" version = "0.6.0" @@ -4647,6 +5097,32 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "ort" +version = "2.0.0-rc.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5f95fe501e1cb81dec2f66ee3129025759b602817aa2c77ff421390c418cc34" +dependencies = [ + "half", + "libloading", + "ndarray", + "ort-sys", + "tracing", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4174960a7b93a17564a05b26e05889f0dea9ee70e68db5841f27b40c0c9804e" +dependencies = [ + "flate2", + "pkg-config", + "sha2", + "tar", + "ureq", +] + [[package]] name = "os_info" version = "3.8.2" @@ -5003,7 +5479,7 @@ dependencies = [ "crc32fast", "fdeflate", "flate2", - "miniz_oxide", + "miniz_oxide 0.7.2", ] [[package]] @@ -5043,6 +5519,15 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +[[package]] +name = "portable-atomic-util" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcdd8420072e66d54a407b3316991fe946ce3ab1083a7f575b2463866624704d" +dependencies = [ + "portable-atomic", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -5063,9 +5548,9 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "prettyplease" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7" +checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", "syn 2.0.61", @@ -5391,6 +5876,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -5803,9 +6294,9 @@ checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "safetensors" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ced76b22c7fba1162f11a5a75d9d8405264b467a07ae0c9c29be119b9297db9" +checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" dependencies = [ "serde", "serde_json", @@ -5829,6 +6320,30 @@ dependencies = [ "libsamplerate", ] +[[package]] +name = "sbv2_core" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa89a844f9b5e6e7bf8353bfca5f6f414c12c2cfcffa86e17da7d58d839208c0" +dependencies = [ + "anyhow", + "dotenvy", + "env_logger", + "hound", + "jpreprocess", + "ndarray", + "num_cpus", + "once_cell", + "ort", + "regex", + "serde", + "serde_json", + "tar", + "thiserror", + "tokenizers 0.20.1", + "zstd", +] + [[package]] name = "schannel" version = "0.1.23" @@ -6002,18 +6517,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.197" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.197" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", @@ -6033,16 +6548,26 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.114" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "indexmap 2.2.6", "itoa 1.0.10", + "memchr", "ryu", "serde", ] +[[package]] +name = "serde_plain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" +dependencies = [ + "serde", +] + [[package]] name = "serde_repr" version = "0.1.18" @@ -6308,6 +6833,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "soup2" version = "0.2.1" @@ -6941,9 +7477,9 @@ dependencies = [ [[package]] name = "tar" -version = "0.4.40" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" +checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909" dependencies = [ "filetime", "libc", @@ -7243,6 +7779,16 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "terminal_size" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21bebf2b7c9e0a515f6e0f8c51dc0f8e4696391e6f1ff30379559f8365fb0df7" +dependencies = [ + "rustix 0.38.32", + "windows-sys 0.48.0", +] + [[package]] name = "thin-slice" version = "0.1.1" @@ -7251,18 +7797,18 @@ checksum = "8eaa81235c7058867fa8c0e7314f33dcce9c215f535d1913822a2b3f5e289f3c" [[package]] name = "thiserror" -version = "1.0.58" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.58" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", @@ -7368,6 +7914,38 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "tokenizers" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b172ffa9a2e5c31bbddc940cd5725d933ced983a9333bbebc4c7eda3bbce1557" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom 0.2.12", + "indicatif", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.8.5", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.8.2", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.36.0" @@ -7770,6 +8348,7 @@ dependencies = [ "rustls-webpki", "serde", "serde_json", + "socks", "url", "webpki-roots", ] @@ -8244,18 +8823,16 @@ dependencies = [ [[package]] name = "whisper-rs" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed7a191078e189d96d029244ab1dff159775adec71dc89a222e9bb9d21a7d161" +version = "0.13.1" +source = "git+https://github.com/arizhih/whisper-rs.git?branch=whisper-cpp-1-7-x#9f56a7350b31ce09f6e95be928beeea837684a59" dependencies = [ "whisper-rs-sys", ] [[package]] name = "whisper-rs-sys" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834f4ca6472b02748c6c282a60cea538f962428f79e685c3205a08efd711336" +version = "0.11.1" +source = "git+https://github.com/arizhih/whisper-rs.git?branch=whisper-cpp-1-7-x#9f56a7350b31ce09f6e95be928beeea837684a59" dependencies = [ "bindgen", "cfg-if", @@ -8757,7 +9334,7 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12b41773911497b18ca8553c3daaf8ec9fe9819caf93d451d3055f69de028adb" dependencies = [ - "derive-new", + "derive-new 0.6.0", "libc", "log", "nix 0.28.0", @@ -8897,6 +9474,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "yada" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aed111bd9e48a802518765906cbdadf0b45afb72b9c81ab049a3b86252adffdd" + [[package]] name = "yoke" version = "0.7.4" @@ -9049,6 +9632,34 @@ dependencies = [ "thiserror", ] +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "zune-core" version = "0.4.12" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 9db86f0..de75e48 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lycoris" -version = "0.9.21" +version = "0.9.22" description = "Lycoris is an offline voice memo" authors = ["solaoi"] license = "MIT" @@ -43,18 +43,15 @@ core-graphics = "0.24.0" objc = "0.2" objc-foundation = "0.1" objc_id = "0.1" -ct2rs = { version = "0.7.3", features = ["accelerate"] } -mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", tag = "v0.3.0", features = [ +ct2rs = { version = "0.9.4", features = ["accelerate"] } +mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "32e8945", features = [ "metal", ] } -# mistralrsの要求がtauriの要求と競合するため、tauriの要求を上書きする -uuid = "=1.10.0" tauri-plugin-clipboard = "1.1.4" - -[target.'cfg(target_arch = "x86_64")'.dependencies] -whisper-rs = { version = "0.11.1", features = ["metal"] } -[target.'cfg(target_arch = "aarch64")'.dependencies] -whisper-rs = { version = "0.11.1", features = ["metal", "coreml"] } +sbv2_core = { version = "0.1.4", features = ["dynamic"] } +# whisper-rs = { version = "0.11.1", features = ["metal"] } +# whisper-rs = { git = "https://github.com/tazz4843/whisper-rs.git", rev = "67924ca", features = ["metal"] } +whisper-rs = { git = "https://github.com/arizhih/whisper-rs.git", branch = "whisper-cpp-1-7-x", features = ["metal"] } [dependencies.tauri-plugin-sql] git = "https://github.com/tauri-apps/plugins-workspace" diff --git a/src-tauri/lib/libonnxruntime.dylib b/src-tauri/lib/libonnxruntime.dylib new file mode 100755 index 0000000..035534e Binary files /dev/null and b/src-tauri/lib/libonnxruntime.dylib differ diff --git a/src-tauri/migrations/001.sql b/src-tauri/migrations/001.sql index 1e919c2..440d830 100644 --- a/src-tauri/migrations/001.sql +++ b/src-tauri/migrations/001.sql @@ -8,6 +8,8 @@ CREATE TABLE speeches ( id INTEGER PRIMARY KEY AUTOINCREMENT, speech_type TEXT, -- speech|memo|screenshot|action + action_type TEXT, + -- chat|suggest created_at_unixtime INTEGER DEFAULT (CAST(strftime('%s', 'now') AS INTEGER)), content TEXT, content_2 TEXT, @@ -35,7 +37,7 @@ VALUES("settingKeyAmivoice", ""); INSERT INTO settings(setting_name, setting_status) VALUES("settingLanguage", "日本語"); INSERT INTO settings(setting_name, setting_status) -VALUES("settingProcess", "文字起こし"); +VALUES("settingProcess", "文字起こし(汎用)"); INSERT INTO settings(setting_name, setting_status) VALUES("settingOnline", "OpenAI"); INSERT INTO settings(setting_name, setting_status) @@ -74,10 +76,14 @@ VALUES("base.en", "whisper"); INSERT INTO models(model_name, model_type) VALUES("large", "whisper"); INSERT INTO models(model_name, model_type) +VALUES("large-turbo", "whisper"); +INSERT INTO models(model_name, model_type) VALUES("large-distil.en", "whisper"); INSERT INTO models(model_name, model_type) VALUES("large-distil.ja", "whisper"); INSERT INTO models(model_name, model_type) +VALUES("large-distil.bilingual", "whisper"); +INSERT INTO models(model_name, model_type) VALUES("medium", "whisper"); INSERT INTO models(model_name, model_type) VALUES("medium.en", "whisper"); @@ -150,6 +156,24 @@ VALUES("small-cs-0.4-rhasspy", "vosk"); INSERT INTO models(model_name, model_type) VALUES("small-pl-0.22", "vosk"); INSERT INTO models(model_name, model_type) -VALUES("fugumt-en-ja", "fugumt"); +VALUES("fugumt-en-ja", "fugumt-en-ja"); +INSERT INTO models(model_name, model_type) +VALUES("fugumt-ja-en", "fugumt-ja-en"); +INSERT INTO models(model_name, model_type) +VALUES("honyaku-13b", "honyaku-13b"); +INSERT INTO models(model_name, model_type) +VALUES("style-bert-vits2", "style-bert-vits2"); +INSERT INTO models(model_name, model_type) +VALUES("tsukuyomi-chan", "style-bert-vits2-voice"); +INSERT INTO models(model_name, model_type) +VALUES("amitaro", "style-bert-vits2-voice"); +INSERT INTO models(model_name, model_type) +VALUES("koharune-ami", "style-bert-vits2-voice"); +INSERT INTO models(model_name, model_type) +VALUES("jvnv-F1-jp", "style-bert-vits2-voice"); +INSERT INTO models(model_name, model_type) +VALUES("jvnv-F2-jp", "style-bert-vits2-voice"); +INSERT INTO models(model_name, model_type) +VALUES("jvnv-M1-jp", "style-bert-vits2-voice"); INSERT INTO models(model_name, model_type) -VALUES("honyaku13b-q4-0", "honyaku13b"); \ No newline at end of file +VALUES("jvnv-M2-jp", "style-bert-vits2-voice"); \ No newline at end of file diff --git a/src-tauri/resources/whisper/ggml-metal.metal b/src-tauri/resources/whisper/ggml-metal.metal deleted file mode 100644 index a7d3f9e..0000000 --- a/src-tauri/resources/whisper/ggml-metal.metal +++ /dev/null @@ -1,5136 +0,0 @@ -#include - -using namespace metal; - -#define MAX(x, y) ((x) > (y) ? (x) : (y)) -#define MIN(x, y) ((x) < (y) ? (x) : (y)) -#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } - -#define QK4_0 32 -#define QR4_0 2 -typedef struct { - half d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; - -#define QK4_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; - -#define QK5_0 32 -typedef struct { - half d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; - -#define QK5_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; - -#define QK8_0 32 -typedef struct { - half d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; - -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 - -enum ggml_sort_order { - GGML_SORT_ASC, - GGML_SORT_DESC, -}; - -// general-purpose kernel for addition, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -kernel void kernel_add( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); - } -} - -kernel void kernel_mul( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); - } -} - -kernel void kernel_div( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig % nb]; -} - -kernel void kernel_mul_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % nb]; -} - -kernel void kernel_div_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] / src1[tpig % nb]; -} - -kernel void kernel_scale( - device const float * src0, - device float * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_scale_4( - device const float4 * src0, - device float4 * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_relu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_tanh( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = precise::tanh(x); -} - -constant float GELU_COEF_A = 0.044715f; -constant float GELU_QUICK_COEF = -1.702f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -kernel void kernel_gelu( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_silu( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_sqr( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sum_rows( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tpig[[thread_position_in_grid]]) { - int64_t i3 = tpig.z; - int64_t i2 = tpig.y; - int64_t i1 = tpig.x; - - if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { - return; - } - - device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); - - float row_sum = 0; - - for (int64_t i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; - } - - dst_row[0] = row_sum; -} - -kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - // parallel max - float lmax = -INFINITY; - - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); - } - - // find the max value in the block - float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = -INFINITY; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = max_val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - max_val = buf[tiisg]; - max_val = simd_max(max_val); - } - - // parallel sum - float lsum = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); - lsum += exp_psrc0; - pdst[i00] = exp_psrc0; - } - - // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 - threadgroup_barrier(mem_flags::mem_none); - - float sum = simd_sum(lsum); - - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - sum = buf[tiisg]; - sum = simd_sum(sum); - } - - const float inv_sum = 1.0f/sum; - - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - pdst[i00] *= inv_sum; - } -} - -kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - // parallel max - float4 lmax4 = -INFINITY; - - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); - } - - const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); - - float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = -INFINITY; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = max_val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - max_val = buf[tiisg]; - max_val = simd_max(max_val); - } - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - - const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; - - // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 - threadgroup_barrier(mem_flags::mem_none); - - float sum = simd_sum(lsum); - - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - sum = buf[tiisg]; - sum = simd_sum(sum); - } - - const float inv_sum = 1.0f/sum; - - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - pdst4[i00] *= inv_sum; - } -} - -kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - device const float4 * src0, - device float4 * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - - const int64_t i = 2*tpig[0]; - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int64_t i4 = 4*i; - const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - const int64_t i01 = i4/(ne00); i4 -= i01*ne00; - const int64_t i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - dst[i+1][k] = -INFINITY; - if (i00 + k > n_past + i01) { - dst[i][k] = -INFINITY; - } - } -} - -kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; - } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float mean = sum[0] / ne00; - - // recenter and VARIANCE - threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; - } - - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float variance = sum[0] / ne00; - - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = y[i00] * scale; - } -} - -kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = all_sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - all_sum = buf[tiisg]; - all_sum = simd_sum(all_sum); - } - - const float mean = all_sum/ne00; - const float scale = 1.0f/sqrt(mean + eps); - - device float4 * y = (device float4 *) (dst + tgpig*ne00); - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - y[i00] = x[i00] * scale; - } -} - -kernel void kernel_group_norm( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int32_t & n_groups, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t ne = ne00*ne01*ne02; - const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); - - int start = tgpig * gs; - int end = start + gs; - - start += tpitg; - - if (end >= ne) { - end = ne; - } - - float tmp = 0.0f; // partial sum for thread in warp - - for (int j = start; j < end; j += ntg) { - tmp += src0[j]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - tmp = simd_sum(tmp); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = tmp; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - tmp = buf[tiisg]; - tmp = simd_sum(tmp); - } - - const float mean = tmp / gs; - tmp = 0.0f; - - for (int j = start; j < end; j += ntg) { - float xi = src0[j] - mean; - dst[j] = xi; - tmp += xi * xi; - } - - tmp = simd_sum(tmp); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = tmp; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - tmp = buf[tiisg]; - tmp = simd_sum(tmp); - } - - const float variance = tmp / gs; - const float scale = 1.0f/sqrt(variance + eps); - for (int j = start; j < end; j += ntg) { - dst[j] *= scale; - } -} - -// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q5 quants begin (0 or QK5_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); - } - return d * (sumy * -16.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q5 quants begin (0 or QK5_1/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// guard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. -template -void mul_vec_q_n_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, uint tiisg, uint sgitg) { - const int nb = ne00/QK4_0; - - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[nr] = {0.f}; - - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; - - device const float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); - } - - yb += QK4_0 * 16; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; - } - } -} - -kernel void kernel_mul_mv_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q4_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q5_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q5_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); -} - - -#define NB_Q8_0 8 - -void kernel_mul_mv_q8_0_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - - const int nb = ne00/QK8_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[NB_Q8_0]; - float sumf[nr]={0.f}; - - const int ix = tiisg/4; - const int il = tiisg%4; - - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; - - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { - yl[i] = yb[i]; - } - - for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; - float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; - } - sumf[row] += sumq*x[ib+row*nb].d; - } - - yb += NB_Q8_0 * nw; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -[[host_name("kernel_mul_mv_q8_0_f32")]] -kernel void kernel_mul_mv_q8_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); -} - -#define N_F32_F32 4 - -void kernel_mul_mv_f32_f32_impl( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F32_F32; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const float * x = (device const float *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const float4 * x4 = (device const float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -[[host_name("kernel_mul_mv_f32_f32")]] -kernel void kernel_mul_mv_f32_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -#define N_F16_F16 4 - -kernel void kernel_mul_mv_f16_f16( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F16; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - device const half4 * y4 = (device const half4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -void kernel_mul_mv_f16_f32_1row_impl( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - device const half4 * x4 = (device const half4 *) x; - device const float4 * y4 = (device const float4 *) y; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_f16_f32_1row")]] -kernel void kernel_mul_mv_f16_f32_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -#define N_F16_F32 4 - -void kernel_mul_mv_f16_f32_impl( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F32; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -[[host_name("kernel_mul_mv_f16_f32")]] -kernel void kernel_mul_mv_f16_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} - -// Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half4 * x4 = (device const half4 *) (src0 + offset0); - - for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -kernel void kernel_alibi_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & m0, - constant float & m1, - constant int & n_heads_log2_floor, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - const int64_t k = i3*ne3 + i2; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = pow(m0, k + 1); - } else { - m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; - device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - const float src_v = *(device float *)(src_row + i00*nb00); - device float * dst_v = (device float *)(dst_row + i00*nb0); - *dst_v = i00 * m_k + src_v; - } -} - -static float rope_yarn_ramp(const float low, const float high, const int i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - thread float * cos_theta, thread float * sin_theta -) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); - } - *cos_theta = cos(theta) * mscale; - *sin_theta = sin(theta) * mscale; -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); -} - -static void rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] -) { - // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); -} - -typedef void (rope_t)( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]); - -template -kernel void kernel_rope( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; - - const bool is_neox = mode & 2; - - float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); - - device const int32_t * pos = src1; - - const int64_t p = pos[i2]; - - const float theta_0 = (float)p; - const float inv_ndims = -1.f/n_dims; - - if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - - const float theta = theta_0 * pow(freq_base, inv_ndims*i0); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const T x0 = src[0]; - const T x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { - if (ic < n_dims) { - const int64_t ib = 0; - - // simplified from `(ib * n_dims + ic) * inv_ndims` - const float cur_rot = inv_ndims*ic - ib; - - const float theta = theta_0 * pow(freq_base, cur_rot); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); - - const int64_t i0 = ib*n_dims + ic/2; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } - } -} - -template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; -template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; - -kernel void kernel_im2col_f16( - device const float * x, - device half * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; - const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; - - const int32_t offset_dst = - (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + - (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; - } -} - -kernel void kernel_upscale_f32( - device const char * src0, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & sf, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1/sf; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = src0_ptr[i0/sf]; - } -} - -kernel void kernel_pad_f32( - device const char * src0, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); - - if (i1 < ne01 && i2 < ne02 && i3 < ne03) { - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; - } - } - - return; - } - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; - } -} - -// bitonic sort implementation following the CUDA kernels as reference -typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]); - -template -kernel void kernel_argsort_f32_i32( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]) { - // bitonic sort - int col = tpitg[0]; - int row = tgpig[1]; - - if (col >= ncols) return; - - device const float * x_row = x + row * ncols; - device int32_t * dst_row = dst + row * ncols; - - // initialize indices - if (col < ncols) { - dst_row[col] = col; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int k = 2; k <= ncols; k *= 2) { - for (int j = k / 2; j > 0; j /= 2) { - int ixj = col ^ j; - if (ixj > col) { - if ((col & k) == 0) { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { - SWAP(dst_row[col], dst_row[ixj]); - } - } else { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { - SWAP(dst_row[col], dst_row[ixj]); - } - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - } -} - -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; - -kernel void kernel_leaky_relu_f32( - device const float * src0, - device float * dst, - constant float & slope, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; -} - -kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f16_f32( - device const half * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - device const float * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_q8_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; - - device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = src[j]; - amax = MAX(amax, fabs(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK8_0].d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = src[j]*id; - - dst_data[i00/QK8_0].qs[j] = round(x0); - } - } -} - -kernel void kernel_cpy_f32_q4_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_0].d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_0/2 + j]*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - - dst_data[i00/QK4_0].qs[j] = xi0; - dst_data[i00/QK4_0].qs[j] |= xi1 << 4; - } - } -} - -kernel void kernel_cpy_f32_q4_1( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < QK4_1; j++) { - const float v = src[j]; - if (min > v) min = v; - if (max < v) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_1].d = d; - dst_data[i00/QK4_1].m = min; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK4_1/2 + j] - min)*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - - dst_data[i00/QK4_1].qs[j] = xi0; - dst_data[i00/QK4_1].qs[j] |= xi1 << 4; - } - } -} - -kernel void kernel_concat( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i02 < ne02) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; - src0_ptr += ntg.x*nb00; - } else { - ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; - src1_ptr += ntg.x*nb10; - } - dst_ptr += ntg.x*nb0; - } -} - -//============================================ k-quants ====================================================== - -#ifndef QK_K -#define QK_K 256 -#else -static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); -#endif - -#if QK_K == 256 -#define K_SCALE_SIZE 12 -#else -#define K_SCALE_SIZE 4 -#endif - -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins -} block_q2_K; -// 84 bytes / block - -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#if QK_K == 64 - uint8_t scales[2]; -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - half d; // super-block scale -} block_q3_K; - -#if QK_K == 64 -typedef struct { - half d[2]; // super-block scales/mins - uint8_t scales[2]; - uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -#endif - -#if QK_K == 64 -typedef struct { - half d; // super-block scales/mins - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -// 176 bytes / block -#endif - -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; -// 210 bytes / block - -//====================================== dot products ========================= - -void kernel_mul_mv_q2_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q2_K) * nb; - -#if QK_K == 256 - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 - - device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - float dall = dh[0]; - float dmin = dh[1] * 1.f/16.f; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 4 * QK_K; - } -#else - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0...1 - - device const float * y4 = y + ix * QK_K + 8 * it; - - for (int ib = ix; ib < nb; ib += 16) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 16 * QK_K; - } -#endif - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_q2_K_f32")]] -kernel void kernel_mul_mv_q2_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); -} - -#if QK_K == 256 -void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - - //const uint16_t kmask1 = 0x3030; - //const uint16_t kmask2 = 0x0f0f; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int ip = tid/4; // 0 or 1 - const int il = 2*((tid%4)/2); // 0 or 2 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; - - // One would think that the Metal compiler would figure out that ip and il can only have - // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it - // with these two tales. - // - // Possible masks for the high bit - const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 - {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 - {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 - {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 - - // Possible masks for the low 2 bits - const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; - - const ushort4 hm = mm[2*ip + il/2]; - - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; - - const uint16_t s_shift1 = 4*ip; - const uint16_t s_shift2 = s_shift1 + il; - - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; - - const int step = sizeof(block_q3_K) * nb / 2; - - device const float * y1 = yy + ix*QK_K + y_offset; - - uint32_t scales32, aux32; - thread uint16_t * scales16 = (thread uint16_t *)&scales32; - thread const int8_t * scales = (thread const int8_t *)&scales32; - - float sumf1[2] = {0.f}; - float sumf2[2] = {0.f}; - for (int i = ix; i < nb; i += 4) { - - for (int l = 0; l < 8; ++l) { - yl[l+ 0] = y1[l+ 0]; - yl[l+ 8] = y1[l+16]; - yl[l+16] = y1[l+32]; - yl[l+24] = y1[l+48]; - } - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); - device const uint16_t * a = (device const uint16_t *)(x[i].scales); - device const half * dh = &x[i].d; - - for (int row = 0; row < 2; ++row) { - - const float d_all = (float)dh[0]; - - scales16[0] = a[4]; - scales16[1] = a[5]; - aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; - scales16[0] = a[il+0]; - scales16[1] = a[il+1]; - scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; - - float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2]; - s1 += yl[l+0] * (qs & qm[il/2][0]); - s2 += yl[l+1] * (qs & qm[il/2][1]); - s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); - s4 += yl[l+16] * (qs & qm[il/2][2]); - s5 += yl[l+17] * (qs & qm[il/2][3]); - s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); - } - float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[0] - 32); - sumf2[row] += d2 * (scales[2] - 32); - - s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2+8]; - s1 += yl[l+8] * (qs & qm[il/2][0]); - s2 += yl[l+9] * (qs & qm[il/2][1]); - s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); - s4 += yl[l+24] * (qs & qm[il/2][2]); - s5 += yl[l+25] * (qs & qm[il/2][3]); - s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); - } - d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[1] - 32); - sumf2[row] += d2 * (scales[3] - 32); - - q += step; - h += step; - a += step; - dh += step; - - } - - y1 += 4 * QK_K; - - } - - for (int row = 0; row < 2; ++row) { - const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); - sumf1[row] = simd_sum(sumf); - } - if (tiisg == 0) { - for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; - } - } -} -#else -void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const int row = 2 * r0 + sgitg; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/4; - const int il = 4 * (tiisg%4);// 0, 4, 8, 12 - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - float2 sum = {0.f, 0.f}; - - for (int i = ix; i < nb; i += 8) { - - const float d_all = (float)(x[i].d); - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); - device const uint16_t * s = (device const uint16_t *)(x[i].scales); - device const float * y = yy + i * QK_K + il; - - const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); - const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; - const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; - const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; - - for (int l = 0; l < 4; l += 2) { - const uint16_t hm = h[l/2] >> iq; - sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) - + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) - + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) - + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); - sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) - + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) - + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) - + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); - } - - } - const float sumf = sum[0] + sum[1] * 1.f/256.f; - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } - -} -#endif - -[[host_name("kernel_mul_mv_q3_K_f32")]] -kernel void kernel_mul_mv_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); -} - -#if QK_K == 256 -void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; - float yh[16]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; - yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; - yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; - yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; - device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & kmask1; - sc16[1] = sc[2] & kmask1; - sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); - sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); - - device const uint16_t * q2 = q1 + 32; - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - sc += step; - dh += step; - } - - y4 += 4 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#else -void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int ix = tiisg/4; // 0...7 - const int it = tiisg%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[8]; - float yh[8]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 8 * it; - - uint16_t sc16[4]; - - for (int ib = ix; ib < nb; ib += 8) { - - float2 sumy = {0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i] = y4[i+ 0]; sumy[0] += yl[i]; - yh[i] = y4[i+32]; sumy[1] += yh[i]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & 0x000f; - sc16[1] = sc[0] & 0x0f00; - sc16[2] = sc[0] & 0x00f0; - sc16[3] = sc[0] & 0xf000; - - float2 acc1 = {0.f, 0.f}; - float2 acc2 = {0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); - acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); - acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); - acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + - (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - - dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); - - qs += step; - sc += step; - dh += step; - } - - y4 += 8 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#endif - -[[host_name("kernel_mul_mv_q4_K_f32")]] -kernel void kernel_mul_mv_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q5_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf[2]={0.f}; - - const int step = sizeof(block_q5_K) * nb; - -#if QK_K == 256 -# - float yl[16], yh[16]; - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int iq = tid/4; - const int ir = tid%4; - const int n = 8; - - const int l0 = n*ir; - const int q_offset = 32*iq + l0; - const int y_offset = 64*iq + l0; - - const uint8_t hm1 = 1u << (2*iq); - const uint8_t hm2 = hm1 << 1; - const uint8_t hm3 = hm1 << 4; - const uint8_t hm4 = hm2 << 4; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - device const float * y1 = yy + ix*QK_K + y_offset; - - for (int i = ix; i < nb; i += 4) { - - device const uint8_t * q1 = x[i].qs + q_offset; - device const uint8_t * qh = x[i].qh + l0; - device const half * dh = &x[i].d; - device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; - - device const float * y2 = y1 + 128; - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { - yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; - yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; - yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; - yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; - } - - for (int row = 0; row < 2; ++row) { - - device const uint8_t * q2 = q1 + 64; - - sc16[0] = a[0] & kmask1; - sc16[1] = a[2] & kmask1; - sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); - sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); - - float4 acc1 = {0.f}; - float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { - uint8_t h = qh[l]; - acc1[0] += yl[l+0] * (q1[l] & 0x0F); - acc1[1] += yl[l+8] * (q1[l] & 0xF0); - acc1[2] += yh[l+0] * (q2[l] & 0x0F); - acc1[3] += yh[l+8] * (q2[l] & 0xF0); - acc2[0] += h & hm1 ? yl[l+0] : 0.f; - acc2[1] += h & hm2 ? yl[l+8] : 0.f; - acc2[2] += h & hm3 ? yh[l+0] : 0.f; - acc2[3] += h & hm4 ? yh[l+8] : 0.f; - } - const float dall = dh[0]; - const float dmin = dh[1]; - sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + - sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + - sc8[4] * (acc1[2] + 16.f*acc2[2]) + - sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - qh += step; - dh += step/2; - a += step/2; - - } - - y1 += 4 * QK_K; - - } -#else - float yl[8], yh[8]; - - const int il = 4 * (tiisg/8); // 0, 4, 8, 12 - const int ix = tiisg%8; - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - device const float * y = yy + ix*QK_K + il; - - for (int i = ix; i < nb; i += 8) { - - for (int l = 0; l < 4; ++l) { - yl[l+0] = y[l+ 0]; - yl[l+4] = y[l+16]; - yh[l+0] = y[l+32]; - yh[l+4] = y[l+48]; - } - - device const half * dh = &x[i].d; - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].qh + in; - device const int8_t * s = x[i].scales; - - for (int row = 0; row < 2; ++row) { - - const float d = dh[0]; - - float2 acc = {0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> iq; - acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) - + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); - acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) - + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); - } - sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); - - q += step; - h += step; - s += step; - dh += step/2; - - } - - y += 8 * QK_K; - } -#endif - - for (int row = 0; row < 2; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -[[host_name("kernel_mul_mv_q5_K_f32")]] -kernel void kernel_mul_mv_q5_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q6_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; - - const int row = 2 * r0 + sgitg; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf = 0; - -#if QK_K == 256 - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; - - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += 2) { - - device const uint8_t * q1 = x[i].ql + q_offset_l; - device const uint8_t * q2 = q1 + 32; - device const uint8_t * qh = x[i].qh + q_offset_h; - device const int8_t * sc = x[i].scales + is; - - device const float * y = yy + i * QK_K + y_offset; - - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); - - } - -#else - const int ix = tiisg/4; - const int il = 4*(tiisg%4); - - for (int i = ix; i < nb; i += 8) { - device const float * y = yy + i * QK_K + il; - device const uint8_t * ql = x[i].ql + il; - device const uint8_t * qh = x[i].qh + il; - device const int8_t * s = x[i].scales; - - const float d = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); - sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); - } - -#endif - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } -} - -[[host_name("kernel_mul_mv_q6_K_f32")]] -kernel void kernel_mul_mv_q6_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); -} - -//============================= templates and their specializations ============================= - -// NOTE: this is not dequantizing - we are simply fitting the template -template -void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { - float4x4 temp = *(((device float4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float md = -8.h * xb->d; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; - reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; - } -} - -template -void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float m = xb->m; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; - reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; - } -} - -template -void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 3); - const float d = xb->d; - const float md = -16.h * xb->d; - const ushort mask = il ? 0x00F0 : 0x000F; - - const uint32_t qh = *((device const uint32_t *)xb->qh); - - const int x_mv = il ? 4 : 0; - - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; - - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; - - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); - - reg[i/2][2*(i%2)+0] = d * x0 + md; - reg[i/2][2*(i%2)+1] = d * x1 + md; - } -} - -template -void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 4); - const float d = xb->d; - const float m = xb->m; - const ushort mask = il ? 0x00F0 : 0x000F; - - const uint32_t qh = *((device const uint32_t *)xb->qh); - - const int x_mv = il ? 4 : 0; - - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; - - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; - - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); - - reg[i/2][2*(i%2)+0] = d * x0 + m; - reg[i/2][2*(i%2)+1] = d * x1 + m; - } -} - -template -void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { - device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); - } -} - -template -void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { - const float d = xb->d; - const float min = xb->dmin; - device const uint8_t * q = (device const uint8_t *)xb->qs; - float dl, ml; - uint8_t sc = xb->scales[il]; - -#if QK_K == 256 - q = q + 32*(il/8) + 16*(il&1); - il = (il/2)%4; -#endif - half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * q = (device const uint8_t *)xb->qs; - device const uint8_t * h = (device const uint8_t *)xb->hmask; - device const int8_t * scales = (device const int8_t *)xb->scales; - -#if QK_K == 256 - q = q + 32 * (il/8) + 16 * (il&1); - h = h + 16 * (il&1); - uint8_t m = 1 << (il/2); - uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ - ((il/4)>0 ? 12 : 3); - uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; - uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; - int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) - : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); - const half ml = 4.h * dl; - - il = (il/2) & 3; - const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl *= coef; - - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); - } -#else - float kcoef = il&1 ? 1.f/16.f : 1.f; - uint16_t kmask = il&1 ? 0xF0 : 0x0F; - float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); - float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - uint8_t m = 1<<(il*2); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); - } -#endif -} - -static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { - return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} - : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; -} - -template -void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { - device const uchar * q = xb->qs; - -#if QK_K == 256 - short is = (il/4) * 2; - q = q + (il/4) * 32 + 16 * (il&1); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; -#else - q = q + 16 * (il&1); - device const uint8_t * s = xb->scales; - device const half2 * dh = (device const half2 *)xb->d; - const float2 d = (float2)dh[0]; - const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; - const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); -#endif - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { - device const uint8_t * q = xb->qs; - device const uint8_t * qh = xb->qh; - -#if QK_K == 256 - short is = (il/4) * 2; - q = q + 32 * (il/4) + 16 * (il&1); - qh = qh + 16 * (il&1); - uint8_t ul = 1 << (il/2); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; - - const ushort mask = il<2 ? 0x0F : 0xF0; - const float qh_val = il<2 ? 16.f : 256.f; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; - } -#else - q = q + 16 * (il&1); - device const int8_t * s = xb->scales; - const float dl = xb->d * s[il]; - uint8_t m = 1<<(il*2); - const float coef = il<2 ? 1.f : 1.f/16.f; - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); - } -#endif -} - -template -void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * ql = (device const uint8_t *)xb->ql; - device const uint8_t * qh = (device const uint8_t *)xb->qh; - device const int8_t * scales = (device const int8_t *)xb->scales; - -#if QK_K == 256 - ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); - qh = qh + 32*(il/8) + 16*(il&1); - half sc = scales[(il%2) + 2 * ((il/2))]; - il = (il/2) & 3; -#else - ql = ql + 16 * (il&1); - half sc = scales[il]; -#endif - const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const half coef = il>1 ? 1.f/16.h : 1.h; - const half ml = d_all * sc * 32.h; - const half dl = d_all * sc * coef; - for (int i = 0; i < 16; ++i) { - const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) - : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); - reg[i/4][i%4] = dl * q - ml; - } -} - -template -kernel void kernel_get_rows( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - //const int64_t i = tgpig; - //const int64_t r = ((device int32_t *) src1)[i]; - - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { - float4x4 temp; - dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; - } -} - -kernel void kernel_get_rows_f32( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_f16( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_i32( - device const void * src0, - device const char * src1, - device int32_t * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - - -#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B -#define BLOCK_SIZE_K 32 -#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A -#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B -#define THREAD_PER_BLOCK 128 -#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers -#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers -#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 -#define SG_MAT_ROW 8 - -// each block_q contains 16*nl weights -template -void kernel_mul_mm_impl(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; - - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; - - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } - - short il = (tiitg % THREAD_PER_ROW); - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); - ushort offset1 = il/nl; - - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); - - #pragma unroll(16) - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - #pragma unroll(4) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - #pragma unroll(4) - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - #pragma unroll(8) - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids -template -void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, - thread short * src1ids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; - - if (r1 * BLOCK_SIZE_N >= ne1) return; - - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; - - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } - - short il = (tiitg % THREAD_PER_ROW); - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); - ushort offset1 = il/nl; - - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - { - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0; - if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mm_impl( - src0, - src1, - dst, - ne00, - ne02, - nb01, - nb02, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - shared_memory, - tgpig, - tiitg, - sgitg); -} - -template -kernel void kernel_mul_mm_id( - device const uchar * ids, - device const uchar * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - // expert id - const int32_t id = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - // row indices of src1 for expert id - int64_t _ne1 = 0; - short src1ids[512]; - - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { - src1ids[_ne1++] = i1; - } - } - - kernel_mul_mm_id_impl( - src0s[id], - src1, - src1ids, - dst, - ne00, - ne02, - nb01, - nb02, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - r2, - r3, - shared_memory, - tgpig, - tiitg, - sgitg); -} - -#if QK_K == 256 -#define QK_NL 16 -#else -#define QK_NL 4 -#endif - -// -// get rows -// - -typedef void (get_rows_t)( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3, uint, uint3); - -//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; - -// -// matrix-matrix multiplication -// - -typedef void (mat_mm_t)( - device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar *, - uint3, uint, uint); - -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; - -// -// indirect matrix-matrix multiplication -// - -typedef void (mat_mm_id_t)( - device const uchar * ids, - device const uchar * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, - threadgroup uchar *, - uint3, uint, uint); - -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; - -// -// matrix-vector multiplication -// - -[[host_name("kernel_mul_mv_id_f32_f32")]] -kernel void kernel_mul_mv_id_f32_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_f32_f32_impl( - src0[id], - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); -} - -[[host_name("kernel_mul_mv_id_f16_f32")]] -kernel void kernel_mul_mv_id_f16_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_f16_f32_impl( - src0[id], - src1 + bid*nb11, - dst + bid*ne0, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); -} - -[[host_name("kernel_mul_mv_id_q8_0_f32")]] -kernel void kernel_mul_mv_id_q8_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q8_0_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q4_0_f32")]] -kernel void kernel_mul_mv_id_q4_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q4_1_f32")]] -kernel void kernel_mul_mv_id_q4_1_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_0_f32")]] -kernel void kernel_mul_mv_id_q5_0_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_1_f32")]] -kernel void kernel_mul_mv_id_q5_1_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - mul_vec_q_n_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q2_K_f32")]] -kernel void kernel_mul_mv_id_q2_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q2_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q3_K_f32")]] -kernel void kernel_mul_mv_id_q3_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q3_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q4_K_f32")]] -kernel void kernel_mul_mv_id_q4_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q4_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q5_K_f32")]] -kernel void kernel_mul_mv_id_q5_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q5_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} - -[[host_name("kernel_mul_mv_id_q6_K_f32")]] -kernel void kernel_mul_mv_id_q6_K_f32( - device const char * ids, - device const char * src1, - device float * dst, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - constant uint & r2, - constant uint & r3, - constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - - const int64_t bid = tgpig.z/(ne12*ne13); - - tgpig.z = tgpig.z%(ne12*ne13); - - const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; - - kernel_mul_mv_q6_K_f32_impl( - src0[id], - (device const float *) (src1 + bid*nb11), - dst + bid*ne0, - ne00, - ne01, - ne02, - ne10, - ne12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg, - sgitg); -} diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 7fea8fd..d752ed4 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -5,17 +5,12 @@ use tauri::{ http::{HttpRange, ResponseBuilder}, - Manager, State, Window, + AppHandle, Manager, State, Window, }; use tauri_plugin_sql::{Migration, MigrationKind}; use std::{ - cmp::min, - env, - io::{Read, Seek, SeekFrom}, - path::PathBuf, - str::FromStr, - sync::{Arc, Mutex}, + cmp::min, env, io::{Read, Seek, SeekFrom}, path::PathBuf, str::FromStr, sync::{Arc, Mutex} }; use crossbeam_channel::{unbounded, Sender}; @@ -28,66 +23,132 @@ use module::{ deleter::NoteDeleter, device::{self, Device}, downloader::{ - fugumt::FugumtModelDownloader, honyaku13b::Honyaku13BModelDownloader, - vosk::VoskModelDownloader, whisper::WhisperModelDownloader, + model_dir::ModelDirDownloader, sbv2::StyleBertVits2ModelDownloader, + sbv2_voice::StyleBertVits2VoiceModelDownloader, vosk::VoskModelDownloader, + whisper::WhisperModelDownloader, }, + model_type_sbv2::ModelTypeStyleBertVits2, model_type_vosk::ModelTypeVosk, model_type_whisper::ModelTypeWhisper, permissions, record::Record, record_desktop::RecordDesktop, screenshot::{self, AppWindow}, + synthesizer::{self, Synthesizer}, transcription::{TraceCompletion, Transcription}, transcription_amivoice::TranscriptionAmivoice, transcription_online::TranscriptionOnline, + translation_en::TranslationEn, translation_ja::TranslationJa, translation_ja_high::TranslationJaHigh, }; struct RecordState(Arc>>>); +struct SynthesizeState(Arc>>); const BUNDLE_IDENTIFIER: &str = "blog.aota.Lycoris"; #[tauri::command] -fn delete_note_command(window: Window, note_id: u64) { +fn list_synthesize_models_command(app_handle: AppHandle) -> Vec { + synthesizer::list_models(app_handle) +} + +#[tauri::command] +async fn synthesize_init_command( + state: State<'_, SynthesizeState>, + app_handle: AppHandle, + model: String, +) -> Result { + let state_clone = state.0.clone(); + let synthesizer = Synthesizer::new(app_handle, model); + let mut lock = state_clone.lock().unwrap(); + *lock = Some(synthesizer); + Ok(true) +} + +#[tauri::command] +fn synthesize_finalize_command(state: State<'_, SynthesizeState>) -> bool { + let mut lock = state.0.lock().unwrap(); + lock.take(); + true +} + +#[tauri::command] +async fn synthesize_command( + state: State<'_, SynthesizeState>, + text: String, + sdp_ratio: f32, + length_scale: f32, +) -> Result, String> { + let mut lock = state.0.lock().unwrap(); + let synthesizer = lock.as_mut().unwrap(); + synthesizer.synthesize(text, sdp_ratio, length_scale) +} + +#[tauri::command] +fn delete_note_command(app_handle: AppHandle, note_id: u64) { std::thread::spawn(move || { - let deleter = NoteDeleter::new(window.app_handle().clone()); + let deleter = NoteDeleter::new(app_handle); deleter.delete(note_id) }); } #[tauri::command] -fn download_whisper_model_command(window: Window, model: String) { +fn download_whisper_model_command(app_handle: AppHandle, model: String) { std::thread::spawn(move || { - let dl = WhisperModelDownloader::new(window.app_handle().clone()); + let dl = WhisperModelDownloader::new(app_handle); dl.download(ModelTypeWhisper::from_str(&model).unwrap()) }); } #[tauri::command] -fn download_vosk_model_command(window: Window, model: String) { +fn download_vosk_model_command(app_handle: AppHandle, model: String) { std::thread::spawn(move || { - let dl = VoskModelDownloader::new(window.app_handle().clone()); + let dl = VoskModelDownloader::new(app_handle); dl.download(ModelTypeVosk::from_str(&model).unwrap()) }); } #[tauri::command] -fn download_fugumt_model_command(window: Window) { +fn download_fugumt_enja_model_command(app_handle: AppHandle) { std::thread::spawn(move || { - let dl = FugumtModelDownloader::new(window.app_handle().clone()); - dl.download() + let dl = ModelDirDownloader::new(app_handle); + dl.download("fugumt-en-ja", "downloadFugumtEnJaProgress") }); } #[tauri::command] -fn download_honyaku13b_model_command(window: Window) { +fn download_fugumt_jaen_model_command(app_handle: AppHandle) { std::thread::spawn(move || { - let dl = Honyaku13BModelDownloader::new(window.app_handle().clone()); + let dl = ModelDirDownloader::new(app_handle); + dl.download("fugumt-ja-en", "downloadFugumtJaEnProgress") + }); +} + +#[tauri::command] +fn download_honyaku13b_model_command(app_handle: AppHandle) { + std::thread::spawn(move || { + let dl = ModelDirDownloader::new(app_handle); + dl.download("honyaku-13b", "downloadHonyaku13BProgress") + }); +} + +#[tauri::command] +fn download_sbv2_command(app_handle: AppHandle) { + std::thread::spawn(move || { + let dl = StyleBertVits2ModelDownloader::new(app_handle); dl.download() }); } +#[tauri::command] +fn download_sbv2_model_command(app_handle: AppHandle, model: String) { + std::thread::spawn(move || { + let dl = StyleBertVits2VoiceModelDownloader::new(app_handle); + dl.download(ModelTypeStyleBertVits2::from_str(&model).unwrap()) + }); +} + #[tauri::command] fn list_devices_command() -> Vec { device::list_devices() @@ -104,8 +165,9 @@ fn list_app_windows_command(app_name: String) -> Vec { } #[tauri::command] -fn screenshot_command(window: Window, window_id: u32, note_id: u64) -> bool { - screenshot::screenshot(window_id, note_id, window.app_handle().clone()) +async fn screenshot_command(app_handle: AppHandle, window_id: u32, note_id: u64) -> Result { + let result = screenshot::screenshot(window_id, note_id, app_handle); + Ok(result) } #[tauri::command] @@ -124,9 +186,9 @@ fn has_microphone_permission_command(window: Window) -> bool { } #[tauri::command] -fn execute_action_command(window: Window, note_id: u64) { +fn execute_action_command(app_handle: AppHandle, note_id: u64) { std::thread::spawn(move || { - if action::initialize_action(window.app_handle().clone(), note_id) { + if action::initialize_action(app_handle, note_id) { let mut lock = action::SINGLETON_INSTANCE.lock().unwrap(); if let Some(singleton) = lock.as_mut() { singleton.execute(); @@ -141,7 +203,7 @@ fn execute_action_command(window: Window, note_id: u64) { #[tauri::command] fn start_command( state: State<'_, RecordState>, - window: Window, + app_handle: AppHandle, device_label: String, speaker_language: String, transcription_accuracy: String, @@ -153,7 +215,7 @@ fn start_command( *lock = Some(stop_record_tx); std::thread::spawn(move || { if device_type == "microphone" { - let record = Record::new(window.app_handle().clone()); + let record = Record::new(app_handle); record.start( device_label, speaker_language, @@ -162,7 +224,7 @@ fn start_command( stop_record_rx, ); } else if device_type == "desktop" { - let record_desktop = RecordDesktop::new(window.app_handle().clone()); + let record_desktop = RecordDesktop::new(app_handle); record_desktop.start( speaker_language, transcription_accuracy, @@ -171,8 +233,8 @@ fn start_command( None, ); } else { - let record = Record::new(window.app_handle().clone()); - let record_desktop = RecordDesktop::new(window.app_handle().clone()); + let record = Record::new(app_handle.clone()); + let record_desktop = RecordDesktop::new(app_handle); let (stop_record_clone_tx, stop_record_clone_rx) = unbounded(); let speaker_language_clone = speaker_language.clone(); @@ -209,7 +271,7 @@ fn stop_command(state: State<'_, RecordState>) { #[tauri::command] fn start_trace_command( state: State<'_, RecordState>, - window: tauri::Window, + app_handle: AppHandle, speaker_language: String, transcription_accuracy: String, note_id: u64, @@ -221,30 +283,31 @@ fn start_trace_command( std::thread::spawn(move || { if transcription_accuracy.starts_with("online-transcript") { let mut transcription_online = TranscriptionOnline::new( - window.app_handle(), + app_handle, transcription_accuracy, speaker_language, note_id, ); transcription_online.start(stop_convert_rx, true); } else if transcription_accuracy.starts_with("online-amivoice") { - let mut transcription_amivoice = - TranscriptionAmivoice::new(window.app_handle(), note_id); + let mut transcription_amivoice = TranscriptionAmivoice::new(app_handle, note_id); transcription_amivoice.start(stop_convert_rx, true); } else if transcription_accuracy.starts_with("online-chat") { - let mut chat_online = ChatOnline::new(window.app_handle(), speaker_language, note_id); + let mut chat_online = ChatOnline::new(app_handle, speaker_language, note_id); chat_online.start(stop_convert_rx, true); } else if transcription_accuracy.starts_with("fugumt-en-ja") { - let mut translation_ja = - TranslationJa::new(window.app_handle(), speaker_language, note_id); + let mut translation_ja = TranslationJa::new(app_handle, speaker_language, note_id); translation_ja.start(stop_convert_rx, true); - } else if transcription_accuracy.starts_with("honyaku13b-q4-0") { + } else if transcription_accuracy.starts_with("fugumt-ja-en") { + let mut translation_en = TranslationEn::new(app_handle, note_id); + translation_en.start(stop_convert_rx, true); + } else if transcription_accuracy.starts_with("honyaku-13b") { let mut translation_ja_high = - TranslationJaHigh::new(window.app_handle(), speaker_language, note_id); + TranslationJaHigh::new(app_handle, speaker_language, note_id); translation_ja_high.start(stop_convert_rx, true); } else { let mut transcription = Transcription::new( - window.app_handle(), + app_handle, transcription_accuracy, speaker_language, note_id, @@ -255,40 +318,18 @@ fn start_trace_command( } #[tauri::command] -fn stop_trace_command(state: State<'_, RecordState>, window: tauri::Window) { +fn stop_trace_command(state: State<'_, RecordState>, app_handle: AppHandle) { let mut lock = state.0.lock().unwrap(); if let Some(stop_convert_tx) = lock.take() { stop_convert_tx.send(()).unwrap_or_else(|_| { - window - .app_handle() + app_handle .emit_all("traceCompletion", TraceCompletion {}) .unwrap(); }) } } -fn set_whisper_metal_lib_path(relative_path: &str) { - if let Ok(exe_path) = env::current_exe() { - if let Some(exe_dir) = exe_path.parent() { - let absolute_path = exe_dir.join(relative_path); - if let Some(absolute_path_str) = absolute_path.to_str() { - println!("Setting GGML_METAL_PATH_RESOURCES to {}", absolute_path_str); - env::set_var("GGML_METAL_PATH_RESOURCES", absolute_path_str); - } - } else { - eprintln!("GGML_METAL_PATH_RESOURCES cloud not be set: Failed to get the executable directory."); - } - } else { - eprintln!("GGML_METAL_PATH_RESOURCES cloud not be set: Failed to get the executable path."); - } -} - fn main() { - #[cfg(not(debug_assertions))] - set_whisper_metal_lib_path("../Resources/resources/whisper"); - #[cfg(debug_assertions)] - set_whisper_metal_lib_path("../../resources/whisper"); - tauri::Builder::default() .register_uri_scheme_protocol("stream", move |_app, request| { let raw_path = request.uri().replace("stream://localhost", ""); @@ -347,12 +388,20 @@ fn main() { .build(), ) .manage(RecordState(Default::default())) + .manage(SynthesizeState(Default::default())) .invoke_handler(tauri::generate_handler![ + list_synthesize_models_command, + synthesize_init_command, + synthesize_finalize_command, + synthesize_command, delete_note_command, download_whisper_model_command, download_vosk_model_command, - download_fugumt_model_command, + download_fugumt_enja_model_command, + download_fugumt_jaen_model_command, download_honyaku13b_model_command, + download_sbv2_command, + download_sbv2_model_command, list_devices_command, list_apps_command, list_app_windows_command, diff --git a/src-tauri/src/module/action.rs b/src-tauri/src/module/action.rs index 1af89bb..a3fcfc3 100644 --- a/src-tauri/src/module/action.rs +++ b/src-tauri/src/module/action.rs @@ -8,8 +8,10 @@ use serde_json::{json, Value}; use tauri::{AppHandle, Manager}; use super::sqlite::{Content, Sqlite}; +use tokio::runtime::Runtime; pub struct Action { + runtime: Runtime, app_handle: AppHandle, sqlite: Sqlite, note_id: u64, @@ -19,12 +21,14 @@ pub struct Action { impl Action { pub fn new(app_handle: AppHandle, note_id: u64) -> Self { + let runtime = Runtime::new().expect("Failed to create Tokio runtime"); let sqlite = Sqlite::new(); let token = sqlite.select_whisper_token().unwrap(); let model = sqlite .select_ai_model() .unwrap_or_else(|_| "gpt-4o-mini".to_string()); Self { + runtime, app_handle, sqlite, note_id, @@ -33,7 +37,6 @@ impl Action { } } - #[tokio::main] async fn request_gpt( model: String, question: String, @@ -90,17 +93,25 @@ impl Action { .push_str(&format!(":::{}\n{}\n:::\n", current_type, current_content)); current_content.clear(); } - prompt.push_str(&format!( - ":::assistant\n[query]\n{}\n[answer]\n{}\n:::\n", - content.content, content.content_2 - )); + if content.action_type == "suggest" { + prompt.push_str(&format!( + ":::assistant\n[query]\n次の発言者のための3つの発話サジェストとその理由を生成してください。\n[answer] {}\n{}\n:::\n", + content.content, content.content_2 + )); + } else { + prompt.push_str(&format!( + ":::assistant\n[query]\n{}\n[answer]\n{}\n:::\n", + content.content, content.content_2 + )); + } } "speech" | _ => { let speech_type = if content.speech_type == "speech" { "transcription" } else if content.speech_type == "memo" { "note" - } else { // "screenshot" + } else { + // "screenshot" "note" }; if speech_type != current_type && !current_content.is_empty() { @@ -160,6 +171,176 @@ impl Action { Ok(response_text) } + async fn request_gpt_suggest( + contents: Vec, + token: String, + ) -> Result> { + let url = "https://api.openai.com/v1/chat/completions"; + let temperature = 0.7; + + let client = Client::new(); + + let mut headers = HeaderMap::new(); + headers.insert( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {}", token))?, + ); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + + let mut messages: Vec = Vec::new(); + let mut prompt = "あなたは高度な会話分析・発話提案AIアシスタントです。提供される情報を分析し、次の発言者に対して適切な発話サジェストを生成することがあなたの役割です。以下の手順に従って処理を行ってください: + +1. 情報の分析: + a) 文字起こし (:::transcription で囲まれた部分):直近の会話の内容と流れを詳細に把握します。 + b) メモ (:::note で囲まれた部分):会話の背景や補足情報として扱います。 + c) 過去のAIとのQ&A (:::assistant で囲まれた部分):関連する追加情報として考慮します。ただし直近の発話サジェストは、今回の発話サジェストが同一性を持たないように考慮します。 + +2. 会話の状況把握: + - 誰が最後に発言したか、どのような内容だったかを特定します。 + - 発話サジェストを受ける人が直前で聞き手だったことを前提とします。 + +3. 発話サジェストの生成: + 会話の流れ、文脈、および背景情報を考慮し、以下の3種類の発言提案を生成します: + a) 中立的な発言:会話を自然に進行させる発言 + b) ポジティブな発言:質問、共感、あるいは会話を前向きな方向に導く発言。 + c) ネガティブな発言:懸念や問題点を指摘する発言 + +4. 各発言提案の理由付け: + それぞれの発言提案について、なぜその発言が適切か、どのような効果が期待できるかを簡潔に説明します。理由付けの詳細さは、必要に応じて調整します。 + +5. 発話サジェストの調整: + - 会話の雰囲気や目的に応じて、各提案の内容や調子を調整します。 + - 文化的背景や社会的文脈を考慮し、適切な表現を選択します。 + +以下に提供される情報を上記の手順に従って分析し、中立的、ポジティブ(質問・共感を含む)、ネガティブな3つの発話サジェストとその理由を生成してください。各サジェストは自然で、会話の流れに沿ったものにしてください: + +".to_string(); + let mut current_type = String::new(); + let mut current_content = String::new(); + + for content in contents.iter() { + match content.speech_type.as_str() { + "action" => { + if !current_content.is_empty() { + prompt + .push_str(&format!(":::{}\n{}\n:::\n", current_type, current_content)); + current_content.clear(); + } + if content.action_type == "suggest" { + prompt.push_str(&format!( + ":::assistant\n[query]\n次の発言者のための3つの発話サジェストとその理由を生成してください。\n[answer] {}\n{}\n:::\n", + content.content, content.content_2 + )); + } else { + prompt.push_str(&format!( + ":::assistant\n[query]\n{}\n[answer]\n{}\n:::\n", + content.content, content.content_2 + )); + } + } + "speech" | _ => { + let speech_type = if content.speech_type == "speech" { + "transcription" + } else if content.speech_type == "memo" { + "note" + } else { + // "screenshot" + "note" + }; + if speech_type != current_type && !current_content.is_empty() { + prompt + .push_str(&format!(":::{}\n{}\n:::\n", current_type, current_content)); + current_content.clear(); + } + current_type = speech_type.to_string(); + current_content.push_str(&content.content); + current_content.push('\n'); + } + } + } + + if !current_content.is_empty() { + prompt.push_str(&format!(":::{}\n{}\n:::\n", current_type, current_content)); + } + + messages.push(json!({ + "role": "system", + "content": prompt + })); + messages.push(json!({ + "role": "user", + "content": "上記の情報を基に、次の発言者のための3つの発話サジェストとその理由を生成してください。" + })); + + // for debugging + // println!("messages: {:?}", messages); + + let response_format = json!({ + "type": "json_schema", + "json_schema": { + "name": "generate_speech_suggestions", + "description": "提供されたコンテキストに基づいて、ニュートラル、ポジティブ、ネガティブな発言の提案とその理由を生成します。", + "strict": true, + "schema": { + "type": "object", + "properties": { + "neutral": { "$ref": "#/$defs/suggestion" }, + "positive": { "$ref": "#/$defs/suggestion" }, + "negative": { "$ref": "#/$defs/suggestion" } + }, + "required": ["neutral", "positive", "negative"], + "additionalProperties": false, + "$defs": { + "suggestion": { + "type": "object", + "description": "発言の提案とその理由", + "properties": { + "content": { + "type": "string", + "description": "提案される発言内容" + }, + "reason": { + "type": "string", + "description": "その発言を提案する理由" + } + }, + "required": ["content", "reason"], + "additionalProperties": false + } + } + } + } + }); + + let post_body = json!({ + "model": "gpt-4o-2024-08-06", + "temperature": temperature, + "messages": messages, + "response_format": response_format + }); + + let response = client + .post(url) + .headers(headers) + .json(&post_body) + .send() + .await?; + + let status = response.status(); + let json_response: Value = response.json().await?; + + let response_text = if status == 200 { + json_response["choices"][0]["message"]["content"] + .as_str() + .unwrap_or("choices[0].message.content field not found") + .to_string() + } else { + json_response.to_string() + }; + + Ok(response_text) + } + pub fn execute(&mut self) { if self.token == "" { println!("whisper token is empty, so skipping..."); @@ -168,62 +349,112 @@ impl Action { let mut is_executing = IS_EXECUTING.lock().unwrap(); *is_executing = true; - while let Ok(action) = self.sqlite.select_first_unexecuted_action(self.note_id) { - match self - .sqlite - .select_has_no_permission_of_execute_action(self.note_id, action.id) - { - Ok(permissions) => { - if permissions.is_empty() || permissions.iter().any(|p| p.model == "whisper") { - match self.sqlite.select_contents_by(self.note_id, action.id) { - Ok(contents) => { - match Self::request_gpt( - self.model.clone(), - action.content, - contents, - self.token.clone(), - ) { - Ok(answer) => { - match self - .sqlite - .update_action_content_2(action.id, answer.clone()) + self.runtime.block_on(async { + while let Ok(action) = self.sqlite.select_first_unexecuted_action(self.note_id) { + match self + .sqlite + .select_has_no_permission_of_execute_action(self.note_id, action.id) + { + Ok(permissions) => { + if permissions.is_empty() + || permissions.iter().any(|p| p.model == "whisper") + { + match self.sqlite.select_contents_by(self.note_id, action.id) { + Ok(contents) => match action.action_type.as_str() { + "chat" => { + match Self::request_gpt( + self.model.clone(), + action.content, + contents, + self.token.clone(), + ) + .await { - Ok(result) => { - let _ = self - .app_handle - .emit_all("actionExecuted", result); + Ok(answer) => { + match self.sqlite.update_action_content_2( + action.id, + answer.clone(), + ) { + Ok(result) => { + let _ = self + .app_handle + .emit_all("actionExecuted", result); + } + Err(e) => { + println!( + "Error updating action content_2: {:?}", + e + ); + break; + } + } } - Err(e) => { + Err(_) => { println!( - "Error updating action content_2: {:?}", - e + "gpt api is temporarily failed, so skipping..." ); break; } } } - Err(_) => { - println!("gpt api is temporarily failed, so skipping..."); + "suggest" => { + match Self::request_gpt_suggest( + contents, + self.token.clone(), + ) + .await + { + Ok(answer) => { + match self.sqlite.update_action_content_2( + action.id, + answer.clone(), + ) { + Ok(result) => { + let _ = self + .app_handle + .emit_all("actionExecuted", result); + } + Err(e) => { + println!( + "Error updating action content_2: {:?}", + e + ); + break; + } + } + } + Err(_) => { + println!( + "gpt api is temporarily failed, so skipping..." + ); + break; + } + } + } + &_ => { + println!("Unsupported action type, so skipping..."); break; } + }, + Err(e) => { + println!("Error selecting contents: {:?}", e); + break; } } - Err(e) => { - println!("Error selecting contents: {:?}", e); - break; - } + } else { + println!( + "has_no_permission_of_execute_action is false, so skipping..." + ); + break; } - } else { - println!("has_no_permission_of_execute_action is false, so skipping..."); + } + Err(e) => { + println!("Error checking permissions: {:?}", e); break; } } - Err(e) => { - println!("Error checking permissions: {:?}", e); - break; - } } - } + }); *is_executing = false; } diff --git a/src-tauri/src/module/device.rs b/src-tauri/src/module/device.rs index 3f79f3a..2ea95c7 100644 --- a/src-tauri/src/module/device.rs +++ b/src-tauri/src/module/device.rs @@ -12,7 +12,7 @@ pub fn list_devices() -> Vec { .input_devices() .unwrap() .filter_map(|device| { - if device.name().is_ok() && device.name().unwrap().contains("ZoomAudioDevice") { + if device.name().is_ok() && (device.name().unwrap().contains("ZoomAudioDevice") || device.name().unwrap().contains("Microsoft Teams Audio")) { None } else { Some(Device { diff --git a/src-tauri/src/module/downloader/mod.rs b/src-tauri/src/module/downloader/mod.rs index 618e3ca..3a42ae9 100644 --- a/src-tauri/src/module/downloader/mod.rs +++ b/src-tauri/src/module/downloader/mod.rs @@ -1,4 +1,5 @@ -pub mod whisper; +pub mod model_dir; +pub mod sbv2; +pub mod sbv2_voice; pub mod vosk; -pub mod fugumt; -pub mod honyaku13b; \ No newline at end of file +pub mod whisper; diff --git a/src-tauri/src/module/downloader/honyaku13b.rs b/src-tauri/src/module/downloader/model_dir.rs similarity index 85% rename from src-tauri/src/module/downloader/honyaku13b.rs rename to src-tauri/src/module/downloader/model_dir.rs index 5cd5b52..415b73e 100644 --- a/src-tauri/src/module/downloader/honyaku13b.rs +++ b/src-tauri/src/module/downloader/model_dir.rs @@ -14,33 +14,32 @@ pub struct Progress { pub is_progress: bool, } -pub struct Honyaku13BModelDownloader { +pub struct ModelDirDownloader { app_handle: AppHandle, } -impl Honyaku13BModelDownloader { +impl ModelDirDownloader { pub fn new(app_handle: AppHandle) -> Self { Self { app_handle } } #[tokio::main] - pub async fn download(&self) { - let model_type = "honyaku13b-q4-0"; + pub async fn download(&self, model_type: &str, progress_identifier: &str) { let path: &str = &self .app_handle .path_resolver() - .resolve_resource("resources/honyaku13b-q4-0.zip") + .resolve_resource(format!("resources/{}.zip", model_type)) .unwrap() .to_string_lossy() .to_string(); - let url = "https://object-storage.tyo1.conoha.io/v1/nc_b22de95e3cf1434da07499038766e2b7/lycoris/honyaku13b-q4-0.zip"; - let res = reqwest::get(url).await.unwrap(); + let url = format!("https://lycoris-storage.wktk.dev/{}.zip", model_type); + let res = reqwest::get(url.clone()).await.unwrap(); let total_size = res .content_length() .ok_or(format!("Failed to get content length from '{}'", url)) .unwrap(); let _ = &self.app_handle.emit_all( - "downloadHonyaku13BProgress", + progress_identifier, Progress { model_type: model_type.to_string(), rate: 0.0, @@ -76,7 +75,7 @@ impl Honyaku13BModelDownloader { let current_rate = ((new as f64 * 100.0) / total_size as f64).round(); if rate != current_rate { let _ = &self.app_handle.emit_all( - "downloadHonyaku13BProgress", + progress_identifier, Progress { model_type: model_type.to_string(), rate: current_rate, @@ -108,7 +107,7 @@ impl Honyaku13BModelDownloader { let _ = Sqlite::new().update_model_is_downloaded(model_type.to_string(), 1); let _ = &self.app_handle.emit_all( - "downloadHonyaku13BProgress", + progress_identifier, Progress { model_type: model_type.to_string(), rate: 0.0, diff --git a/src-tauri/src/module/downloader/fugumt.rs b/src-tauri/src/module/downloader/sbv2.rs similarity index 86% rename from src-tauri/src/module/downloader/fugumt.rs rename to src-tauri/src/module/downloader/sbv2.rs index 5df67aa..d8c299e 100644 --- a/src-tauri/src/module/downloader/fugumt.rs +++ b/src-tauri/src/module/downloader/sbv2.rs @@ -14,25 +14,25 @@ pub struct Progress { pub is_progress: bool, } -pub struct FugumtModelDownloader { +pub struct StyleBertVits2ModelDownloader { app_handle: AppHandle, } -impl FugumtModelDownloader { +impl StyleBertVits2ModelDownloader { pub fn new(app_handle: AppHandle) -> Self { Self { app_handle } } #[tokio::main] pub async fn download(&self) { - let model_type = "fugumt-en-ja"; + let model_type = "style-bert-vits2"; let path: &str = &self .app_handle .path_resolver() - .resolve_resource("resources/fugumt-en-ja.zip") + .resolve_resource("resources/style-bert-vits/style-bert-vits.zip") .unwrap() .to_string_lossy() .to_string(); - let url = "https://object-storage.tyo1.conoha.io/v1/nc_b22de95e3cf1434da07499038766e2b7/lycoris/fugumt-en-ja.zip"; + let url = "https://lycoris-storage.wktk.dev/style-bert-vits.zip"; let res = reqwest::get(url).await.unwrap(); let total_size = res .content_length() @@ -40,7 +40,7 @@ impl FugumtModelDownloader { .unwrap(); let _ = &self.app_handle.emit_all( - "downloadFugumtProgress", + "downloadStyleBertVits2Progress", Progress { model_type: model_type.to_string(), rate: 0.0, @@ -76,7 +76,7 @@ impl FugumtModelDownloader { let current_rate = ((new as f64 * 100.0) / total_size as f64).round(); if rate != current_rate { let _ = &self.app_handle.emit_all( - "downloadFugumtProgress", + "downloadStyleBertVits2Progress", Progress { model_type: model_type.to_string(), rate: current_rate, @@ -90,7 +90,7 @@ impl FugumtModelDownloader { let dir: &str = &self .app_handle .path_resolver() - .resolve_resource("resources") + .resolve_resource("resources/style-bert-vits") .unwrap() .to_string_lossy() .to_string(); @@ -108,7 +108,7 @@ impl FugumtModelDownloader { let _ = Sqlite::new().update_model_is_downloaded(model_type.to_string(), 1); let _ = &self.app_handle.emit_all( - "downloadFugumtProgress", + "downloadStyleBertVits2Progress", Progress { model_type: model_type.to_string(), rate: 0.0, diff --git a/src-tauri/src/module/downloader/sbv2_voice.rs b/src-tauri/src/module/downloader/sbv2_voice.rs new file mode 100644 index 0000000..e711431 --- /dev/null +++ b/src-tauri/src/module/downloader/sbv2_voice.rs @@ -0,0 +1,105 @@ +use tauri::{AppHandle, Manager}; + +use futures_util::StreamExt; +use std::cmp::min; +use std::fs::File; +use std::io::Write; + +use crate::module::model_type_sbv2::ModelTypeStyleBertVits2; +use crate::module::sqlite::Sqlite; + +#[derive(Debug, Clone, serde::Serialize)] +pub struct Progress { + pub model_type: String, + pub rate: f64, + pub is_progress: bool, +} + +pub struct StyleBertVits2VoiceModelDownloader { + app_handle: AppHandle, +} +impl StyleBertVits2VoiceModelDownloader { + pub fn new(app_handle: AppHandle) -> Self { + Self { app_handle } + } + + #[tokio::main] + pub async fn download(&self, model_type: ModelTypeStyleBertVits2) { + let model_path: &str = &format!("resources/style-bert-vits/models/{}.sbv2", model_type.as_str()); + let path: &str = &self + .app_handle + .path_resolver() + .resolve_resource(model_path) + .unwrap() + .to_string_lossy() + .to_string(); + let url: &str = &format!( + "https://lycoris-storage.wktk.dev/{}.sbv2", + model_type.as_str() + ); + let res = reqwest::get(url).await.unwrap(); + let total_size = res + .content_length() + .ok_or(format!("Failed to get content length from '{}'", url)) + .unwrap(); + + let _ = &self.app_handle.emit_all( + "downloadStyleBertVits2VoiceProgress", + Progress { + model_type: model_type.as_str().to_string(), + rate: 0.0, + is_progress: true, + }, + ); + + let mut file; + let mut downloaded: u64 = 0; + let mut stream = res.bytes_stream(); + + println!("Seeking in file."); + if std::path::Path::new(&path).exists() { + println!("File exists. Removig..."); + let _ = std::fs::remove_file(&path); + } + file = File::create(&path) + .or(Err(format!("Failed to create file '{}'", &path))) + .unwrap(); + + println!("Commencing transfer"); + let mut rate = 0.0; + while let Some(item) = stream.next().await { + let chunk = item + .or(Err(format!("Error while downloading file"))) + .unwrap(); + file.write(&chunk) + .or(Err(format!("Error while writing to file"))) + .unwrap(); + let new = min(downloaded + (chunk.len() as u64), total_size); + downloaded = new; + + let current_rate = ((new as f64 * 100.0) / total_size as f64).round(); + if rate != current_rate { + let _ = &self.app_handle.emit_all( + "downloadStyleBertVits2VoiceProgress", + Progress { + model_type: model_type.as_str().to_string(), + rate: current_rate, + is_progress: true, + }, + ); + rate = current_rate + } + } + + let _ = Sqlite::new().update_model_is_downloaded(model_type.as_str().to_string(), 1); + + let _ = &self.app_handle.emit_all( + "downloadStyleBertVits2VoiceProgress", + Progress { + model_type: model_type.as_str().to_string(), + rate, + is_progress: false, + }, + ); + } +} diff --git a/src-tauri/src/module/downloader/vosk.rs b/src-tauri/src/module/downloader/vosk.rs index 8fbcf07..de0678d 100644 --- a/src-tauri/src/module/downloader/vosk.rs +++ b/src-tauri/src/module/downloader/vosk.rs @@ -34,7 +34,7 @@ impl VoskModelDownloader { .to_string_lossy() .to_string(); let url: &str = &format!( - "https://object-storage.tyo1.conoha.io/v1/nc_b22de95e3cf1434da07499038766e2b7/lycoris/vosk-model-{}.zip", + "https://lycoris-storage.wktk.dev/vosk-model-{}.zip", model_type.as_str() ); let res = reqwest::get(url).await.unwrap(); diff --git a/src-tauri/src/module/downloader/whisper.rs b/src-tauri/src/module/downloader/whisper.rs index 943846d..e283ff0 100644 --- a/src-tauri/src/module/downloader/whisper.rs +++ b/src-tauri/src/module/downloader/whisper.rs @@ -34,7 +34,7 @@ impl WhisperModelDownloader { .to_string_lossy() .to_string(); let url: &str = &format!( - "https://object-storage.tyo1.conoha.io/v1/nc_b22de95e3cf1434da07499038766e2b7/lycoris/ggml-{}.zip", + "https://lycoris-storage.wktk.dev/ggml-{}.zip", model_type.as_str() ); let res = reqwest::get(url).await.unwrap(); diff --git a/src-tauri/src/module/mod.rs b/src-tauri/src/module/mod.rs index 7f3de43..c4434b9 100644 --- a/src-tauri/src/module/mod.rs +++ b/src-tauri/src/module/mod.rs @@ -1,20 +1,23 @@ +pub mod action; pub mod chat_online; pub mod deleter; pub mod device; pub mod downloader; +pub mod model_type_sbv2; pub mod model_type_vosk; pub mod model_type_whisper; pub mod permissions; mod recognizer; pub mod record; pub mod record_desktop; +pub mod screenshot; mod sqlite; +pub mod synthesizer; mod transcriber; pub mod transcription; pub mod transcription_amivoice; pub mod transcription_online; +pub mod translation_en; pub mod translation_ja; pub mod translation_ja_high; mod writer; -pub mod screenshot; -pub mod action; diff --git a/src-tauri/src/module/model_type_sbv2.rs b/src-tauri/src/module/model_type_sbv2.rs new file mode 100644 index 0000000..ef6ed1e --- /dev/null +++ b/src-tauri/src/module/model_type_sbv2.rs @@ -0,0 +1,42 @@ +use std::str::FromStr; + +pub enum ModelTypeStyleBertVits2 { + TsukuyomiChan, + KoharuneAmi, + Amitaro, + JvnvF1Jp, + JvnvF2Jp, + JvnvM1Jp, + JvnvM2Jp, +} + +impl ModelTypeStyleBertVits2 { + pub fn as_str(&self) -> &'static str { + match self { + ModelTypeStyleBertVits2::TsukuyomiChan => "tsukuyomi-chan", + ModelTypeStyleBertVits2::KoharuneAmi => "koharune-ami", + ModelTypeStyleBertVits2::Amitaro => "amitaro", + ModelTypeStyleBertVits2::JvnvF1Jp => "jvnv-F1-jp", + ModelTypeStyleBertVits2::JvnvF2Jp => "jvnv-F2-jp", + ModelTypeStyleBertVits2::JvnvM1Jp => "jvnv-M1-jp", + ModelTypeStyleBertVits2::JvnvM2Jp => "jvnv-M2-jp", + } + } +} + +impl FromStr for ModelTypeStyleBertVits2 { + type Err = (); + + fn from_str(input: &str) -> Result { + match input { + "tsukuyomi-chan" => Ok(ModelTypeStyleBertVits2::TsukuyomiChan), + "koharune-ami" => Ok(ModelTypeStyleBertVits2::KoharuneAmi), + "amitaro" => Ok(ModelTypeStyleBertVits2::Amitaro), + "jvnv-F1-jp" => Ok(ModelTypeStyleBertVits2::JvnvF1Jp), + "jvnv-F2-jp" => Ok(ModelTypeStyleBertVits2::JvnvF2Jp), + "jvnv-M1-jp" => Ok(ModelTypeStyleBertVits2::JvnvM1Jp), + "jvnv-M2-jp" => Ok(ModelTypeStyleBertVits2::JvnvM2Jp), + _ => Err(()), + } + } +} diff --git a/src-tauri/src/module/model_type_whisper.rs b/src-tauri/src/module/model_type_whisper.rs index 7c1d3ad..c7e73c7 100644 --- a/src-tauri/src/module/model_type_whisper.rs +++ b/src-tauri/src/module/model_type_whisper.rs @@ -4,8 +4,10 @@ pub enum ModelTypeWhisper { Base, BaseEn, Large, + LargeTurbo, LargeDistilEn, LargeDistilJa, + LargeDistilBilingual, Medium, MediumEn, Small, @@ -20,8 +22,10 @@ impl ModelTypeWhisper { ModelTypeWhisper::Base => "base", ModelTypeWhisper::BaseEn => "base.en", ModelTypeWhisper::Large => "large", + ModelTypeWhisper::LargeTurbo => "large-turbo", ModelTypeWhisper::LargeDistilEn => "large-distil.en", ModelTypeWhisper::LargeDistilJa => "large-distil.ja", + ModelTypeWhisper::LargeDistilBilingual => "large-distil.bilingual", ModelTypeWhisper::Medium => "medium", ModelTypeWhisper::MediumEn => "medium.en", ModelTypeWhisper::Small => "small", @@ -40,8 +44,10 @@ impl FromStr for ModelTypeWhisper { "base" => Ok(ModelTypeWhisper::Base), "base.en" => Ok(ModelTypeWhisper::BaseEn), "large" => Ok(ModelTypeWhisper::Large), + "large-turbo" => Ok(ModelTypeWhisper::LargeTurbo), "large-distil.en" => Ok(ModelTypeWhisper::LargeDistilEn), "large-distil.ja" => Ok(ModelTypeWhisper::LargeDistilJa), + "large-distil.bilingual" => Ok(ModelTypeWhisper::LargeDistilBilingual), "medium" => Ok(ModelTypeWhisper::Medium), "medium.en" => Ok(ModelTypeWhisper::MediumEn), "small" => Ok(ModelTypeWhisper::Small), diff --git a/src-tauri/src/module/record.rs b/src-tauri/src/module/record.rs index cd88920..037ed1f 100644 --- a/src-tauri/src/module/record.rs +++ b/src-tauri/src/module/record.rs @@ -24,7 +24,7 @@ use tauri::{api::path::data_dir, AppHandle, Manager}; use super::{ chat_online, recognizer::MyRecognizer, sqlite::Sqlite, transcription, transcription_amivoice, - transcription_online, translation_ja, translation_ja_high, writer::Writer, + transcription_online, translation_en, translation_ja, translation_ja_high, writer::Writer, }; pub struct Record { @@ -228,7 +228,16 @@ impl Record { if let Some(singleton) = lock.as_mut() { singleton.start(stop_convert_rx_clone, false); } - } else if transcription_accuracy_clone.starts_with("honyaku13b-q4-0") { + } else if transcription_accuracy_clone.starts_with("fugumt-ja-en") { + translation_en::initialize_translation_en( + app_handle_clone, + note_id, + ); + let mut lock = translation_en::SINGLETON_INSTANCE.lock().unwrap(); + if let Some(singleton) = lock.as_mut() { + singleton.start(stop_convert_rx_clone, false); + } + } else if transcription_accuracy_clone.starts_with("honyaku-13b") { translation_ja_high::initialize_translation_ja_high( app_handle_clone, speaker_language_clone, @@ -275,6 +284,7 @@ impl Record { if !is_no_transcription { stop_convert_tx.send(()).unwrap(); transcription::drop_transcription(); + translation_en::drop_translation_en(); translation_ja::drop_translation_ja(); translation_ja_high::drop_translation_ja_high(); transcription_online::drop_transcription_online(); diff --git a/src-tauri/src/module/record_desktop.rs b/src-tauri/src/module/record_desktop.rs index 005926a..c94cd25 100644 --- a/src-tauri/src/module/record_desktop.rs +++ b/src-tauri/src/module/record_desktop.rs @@ -39,7 +39,7 @@ use vosk::Recognizer; use super::{ chat_online, recognizer::MyRecognizer, sqlite::Sqlite, transcription, transcription_amivoice, - transcription_online, translation_ja, translation_ja_high, writer::Writer, + transcription_online, translation_en, translation_ja, translation_ja_high, writer::Writer, }; pub struct RecordDesktop { @@ -260,7 +260,16 @@ impl RecordDesktop { if let Some(singleton) = lock.as_mut() { singleton.start(stop_convert_rx_clone, false); } - } else if transcription_accuracy_clone.starts_with("honyaku13b-q4-0") { + } else if transcription_accuracy_clone.starts_with("fugumt-ja-en") { + translation_en::initialize_translation_en( + app_handle_clone, + note_id, + ); + let mut lock = translation_en::SINGLETON_INSTANCE.lock().unwrap(); + if let Some(singleton) = lock.as_mut() { + singleton.start(stop_convert_rx_clone, false); + } + } else if transcription_accuracy_clone.starts_with("honyaku-13b") { translation_ja_high::initialize_translation_ja_high( app_handle_clone, speaker_language_clone, @@ -311,6 +320,7 @@ impl RecordDesktop { if !is_no_transcription { stop_convert_tx.send(()).unwrap(); transcription::drop_transcription(); + translation_en::drop_translation_en(); translation_ja::drop_translation_ja(); translation_ja_high::drop_translation_ja_high(); transcription_online::drop_transcription_online(); diff --git a/src-tauri/src/module/sqlite.rs b/src-tauri/src/module/sqlite.rs index 3aac31c..a57701e 100644 --- a/src-tauri/src/module/sqlite.rs +++ b/src-tauri/src/module/sqlite.rs @@ -29,12 +29,14 @@ pub struct Updated { #[derive(Debug, Clone, serde::Serialize)] pub struct UnexecutedAction { pub id: u16, + pub action_type: String, pub content: String, } #[derive(Debug, Clone, serde::Serialize)] pub struct Content { pub speech_type: String, + pub action_type: String, pub content: String, pub content_2: String, } @@ -195,13 +197,14 @@ impl Sqlite { note_id: u64, id: u16, ) -> Result, rusqlite::Error> { - let mut stmt = self.conn.prepare("SELECT speech_type,content,content_2 FROM speeches WHERE note_id = ?1 AND id < ?2 ORDER BY created_at_unixtime ASC").unwrap(); + let mut stmt = self.conn.prepare("SELECT speech_type,action_type,content,content_2 FROM speeches WHERE note_id = ?1 AND id < ?2 ORDER BY created_at_unixtime ASC").unwrap(); let results = stmt .query_map(params![note_id, id], |row| { Ok(Content { speech_type: row.get_unwrap(0), - content: row.get_unwrap(1), - content_2: row.get(2).unwrap_or_default(), + action_type: row.get(1).unwrap_or_default(), + content: row.get_unwrap(2), + content_2: row.get(3).unwrap_or_default(), }) }) .unwrap() @@ -213,9 +216,9 @@ impl Sqlite { &self, note_id: u64, ) -> Result { - return self.conn.query_row("SELECT id, content FROM speeches WHERE speech_type = \"action\" AND content_2 IS NULL AND note_id = ?1 ORDER BY created_at_unixtime ASC LIMIT 1", + return self.conn.query_row("SELECT id, action_type, content FROM speeches WHERE speech_type = \"action\" AND content_2 IS NULL AND note_id = ?1 ORDER BY created_at_unixtime ASC LIMIT 1", params![note_id], - |row| Ok(UnexecutedAction{id: row.get_unwrap(0), content: row.get_unwrap(1)}), + |row| Ok(UnexecutedAction{id: row.get_unwrap(0), action_type: row.get_unwrap(1), content: row.get_unwrap(2)}), ); } diff --git a/src-tauri/src/module/synthesizer.rs b/src-tauri/src/module/synthesizer.rs new file mode 100644 index 0000000..51fcfbb --- /dev/null +++ b/src-tauri/src/module/synthesizer.rs @@ -0,0 +1,96 @@ +use std::fs; + +use sbv2_core::tts::TTSModelHolder; +use tauri::AppHandle; + +pub fn list_models(app_handle: AppHandle) -> Vec { + let models_path = app_handle + .path_resolver() + .resolve_resource("resources/style-bert-vits/models".to_string()) + .unwrap() + .to_string_lossy() + .to_string(); + + let dir = fs::read_dir(models_path).unwrap(); + let mut models: Vec = Vec::new(); + for item in dir.into_iter() { + let name = item.unwrap().file_name().to_string_lossy().to_string(); + if name.ends_with(".sbv2") { + let entry = &name[..name.len() - 5]; + models.push(entry.to_string()); + } + } + + models +} + +pub struct Synthesizer { + ident: String, + tts_model: TTSModelHolder, +} + +impl Synthesizer { + pub fn new(app_handle: AppHandle, model: String) -> Self { + let bert_model_path = app_handle + .path_resolver() + .resolve_resource(format!("resources/style-bert-vits/{}", "deberta.onnx")) + .unwrap() + .to_string_lossy() + .to_string(); + let tokenizer_path = app_handle + .path_resolver() + .resolve_resource(format!("resources/style-bert-vits/{}", "tokenizer.json")) + .unwrap() + .to_string_lossy() + .to_string(); + let models_path = app_handle + .path_resolver() + .resolve_resource("resources/style-bert-vits/models".to_string()) + .unwrap() + .to_string_lossy() + .to_string(); + + let mut tts_model = TTSModelHolder::new( + &fs::read(bert_model_path).unwrap(), + &fs::read(tokenizer_path).unwrap(), + ) + .unwrap(); + + let sbv2_bytes = fs::read(format!("{models_path}/{}.sbv2", model.clone())).unwrap(); + let _ = tts_model.load_sbv2file(model.clone(), sbv2_bytes); + + Self { + ident: model, + tts_model, + } + } + + pub fn synthesize( + &mut self, + text: String, + sdp_ratio: f32, + length_scale: f32, + ) -> Result, String> { + let (bert_ori, phones, tones, lang_ids) = self.tts_model.parse_text(&text).unwrap(); + + let style_vector = self + .tts_model + .get_style_vector(self.ident.clone(), 0, 1.0) + .unwrap(); + let buffer = self + .tts_model + .synthesize( + self.ident.clone(), + bert_ori.to_owned(), + phones, + tones, + lang_ids, + style_vector, + sdp_ratio, + length_scale, + ) + .unwrap(); + + Ok(buffer) + } +} diff --git a/src-tauri/src/module/transcriber.rs b/src-tauri/src/module/transcriber.rs index c0c11e1..1b245a9 100644 --- a/src-tauri/src/module/transcriber.rs +++ b/src-tauri/src/module/transcriber.rs @@ -16,6 +16,10 @@ impl Transcriber { model_type = "large-distil.en" } else if transcription_accuracy.starts_with("large-distil.ja") { model_type = "large-distil.ja" + } else if transcription_accuracy.starts_with("large-distil.bilingual") { + model_type = "large-distil.bilingual" + } else if transcription_accuracy.starts_with("large-turbo") { + model_type = "large-turbo" } else if transcription_accuracy.starts_with("large") { model_type = "large" } @@ -26,8 +30,14 @@ impl Transcriber { .to_string_lossy() .to_string(); - return WhisperContext::new_with_params(&model_path, WhisperContextParameters::default()) - .expect("failed to load whisper model"); + return WhisperContext::new_with_params( + &model_path, + WhisperContextParameters { + flash_attn: true, + ..WhisperContextParameters::default() + }, + ) + .expect("failed to load whisper model"); } pub fn build_params( @@ -85,12 +95,64 @@ impl Transcriber { ); println!("working on {} threads.", hardware_concurrency); params.set_n_threads(hardware_concurrency); - if transcription_accuracy.ends_with("en") { + + if transcription_accuracy.starts_with("large-distil.bilingual") { params.set_translate(true); + if language == "en" { + params.set_initial_prompt("こんにちは、私の講義へようこそ。"); + params.set_language(Some("ja")); + } else if language == "ja" { + params.set_initial_prompt("Hello, welcome to my lecture."); + params.set_language(Some("en")); + } } else { - params.set_translate(false); + params.set_language(Some(language)); + if transcription_accuracy.ends_with("en") { + params.set_translate(true); + params.set_initial_prompt("Hello, welcome to my lecture."); + } else { + params.set_translate(false); + if language == "en" { + params.set_initial_prompt("Hello, welcome to my lecture."); + } else if language == "zh" { + params.set_initial_prompt("你好,欢迎来到我的讲座。"); + } else if language == "ko" { + params.set_initial_prompt("안녕하세요, 제 강의에 오신 것을 환영합니다."); + } else if language == "fr" { + params.set_initial_prompt("Bonjour, bienvenue à mon cours."); + } else if language == "de" { + params.set_initial_prompt("Hallo, willkommen zu meiner Vorlesung."); + } else if language == "ru" { + params.set_initial_prompt("Привет, добро пожаловать на мою лекцию."); + } else if language == "es" { + params.set_initial_prompt("Hola, bienvenido a mi conferencia."); + } else if language == "pt" { + params.set_initial_prompt("Olá, bem-vindo à minha palestra."); + } else if language == "tr" { + params.set_initial_prompt("Merhaba, dersime hoş geldiniz."); + } else if language == "vi" { + params.set_initial_prompt("Xin chào, chào mừng bạn đến với bài giảng của tôi."); + } else if language == "it" { + params.set_initial_prompt("Ciao, benvenuto alla mia conferenza."); + } else if language == "nl" { + params.set_initial_prompt("Hallo, welkom bij mijn lezing."); + } else if language == "ca" { + params.set_initial_prompt("Hola, benvingut a la meva conferència."); + } else if language == "uk" { + params.set_initial_prompt("Привіт, ласкаво просимо на мою лекцію."); + } else if language == "sv" { + params.set_initial_prompt("Hej, välkommen till min föreläsning."); + } else if language == "hi" { + params.set_initial_prompt("नमस्ते, मेरे व्याख्यान में आपका स्वागत है।"); + } else if language == "cs" { + params.set_initial_prompt("Ahoj, vítejte na mé přednášce."); + } else if language == "pl" { + params.set_initial_prompt("Cześć, witaj na mojej wykładzie."); + } else if language == "ja" { + params.set_initial_prompt("こんにちは、私の講義へようこそ。"); + } + } } - params.set_language(Some(language)); params.set_print_special(false); params.set_print_progress(false); params.set_print_realtime(false); diff --git a/src-tauri/src/module/transcription_online.rs b/src-tauri/src/module/transcription_online.rs index 4ba0b3c..1b80fb0 100644 --- a/src-tauri/src/module/transcription_online.rs +++ b/src-tauri/src/module/transcription_online.rs @@ -146,16 +146,64 @@ impl TranscriptionOnline { "ja" }; let part_language = multipart::Part::text(language); + let prompt = if is_translate { + "Hello, welcome to my lecture." + } else { + if language == "en" { + "Hello, welcome to my lecture." + } else if language == "zh" { + "你好,欢迎来到我的讲座。" + } else if language == "ko" { + "안녕하세요, 제 강의에 오신 것을 환영합니다." + } else if language == "fr" { + "Bonjour, bienvenue à mon cours." + } else if language == "de" { + "Hallo, willkommen zu meiner Vorlesung." + } else if language == "ru" { + "Привет, добро пожаловать на мою лекцию." + } else if language == "es" { + "Hola, bienvenido a mi conferencia." + } else if language == "pt" { + "Olá, bem-vindo à minha palestra." + } else if language == "tr" { + "Merhaba, dersime hoş geldiniz." + } else if language == "vi" { + "Xin chào, chào mừng bạn đến với bài giảng của tôi." + } else if language == "it" { + "Ciao, benvenuto alla mia conferenza." + } else if language == "nl" { + "Hallo, welkom bij mijn lezing." + } else if language == "ca" { + "Hola, benvingut a la meva conferència." + } else if language == "uk" { + "Привіт, ласкаво просимо на мою лекцію." + } else if language == "sv" { + "Hej, välkommen till min föreläsning." + } else if language == "hi" { + "नमस्ते, मेरे व्याख्यान में आपका स्वागत है।" + } else if language == "cs" { + "Ahoj, vítejte na mé přednášce." + } else if language == "pl" { + "Cześć, witaj na mojej wykładzie." + } else if language == "ja" { + "こんにちは、私の講義へようこそ。" + } else { + "Hello, welcome to my lecture." + } + }; + let part_prompt = multipart::Part::text(prompt); let form = if is_translate { multipart::Form::new() .part("file", part_file) .part("model", part_model) + .part("prompt", part_prompt) } else { multipart::Form::new() .part("file", part_file) .part("model", part_model) .part("language", part_language) + .part("prompt", part_prompt) }; let response = client diff --git a/src-tauri/src/module/translation_en.rs b/src-tauri/src/module/translation_en.rs new file mode 100644 index 0000000..8865e1d --- /dev/null +++ b/src-tauri/src/module/translation_en.rs @@ -0,0 +1,190 @@ +use super::{sqlite::Sqlite, transcriber::Transcriber}; + +use crossbeam_channel::Receiver; +use ct2rs::{tokenizers::auto::Tokenizer, Config, TranslationOptions, Translator}; +use hound::SampleFormat; +use samplerate_rs::{convert, ConverterType}; +use std::sync::Mutex; +use tauri::{AppHandle, Manager}; +use whisper_rs::WhisperContext; + +#[derive(Debug, Clone, serde::Serialize)] +pub struct TraceCompletion {} + +pub struct TranslationEn { + app_handle: AppHandle, + sqlite: Sqlite, + ctx: WhisperContext, + translator: Translator, + note_id: u64, +} + +impl TranslationEn { + pub fn new(app_handle: AppHandle, note_id: u64) -> Self { + let app_handle_clone = app_handle.clone(); + let model_path = app_handle + .path_resolver() + .resolve_resource(format!("resources/fugumt-ja-en")) + .unwrap() + .to_string_lossy() + .to_string(); + + TranslationEn { + app_handle, + sqlite: Sqlite::new(), + ctx: Transcriber::build(app_handle_clone, "large".to_string()), + translator: Translator::new(&model_path, &Config::default()).unwrap(), + note_id, + } + } + + pub fn start(&mut self, stop_convert_rx: Receiver<()>, is_continuous: bool) { + while Self::convert(self).is_ok() { + if is_continuous { + let vosk_speech = self.sqlite.select_vosk(self.note_id); + if vosk_speech.is_err() { + self.app_handle + .clone() + .emit_all("traceCompletion", TraceCompletion {}) + .unwrap(); + break; + } + } + if stop_convert_rx.try_recv().is_ok() { + let vosk_speech = self.sqlite.select_vosk(self.note_id); + if vosk_speech.is_err() { + self.app_handle + .clone() + .emit_all("traceCompletion", TraceCompletion {}) + .unwrap(); + } else { + self.app_handle + .clone() + .emit_all("traceUnCompletion", TraceCompletion {}) + .unwrap(); + } + break; + } + } + } + + fn convert(&mut self) -> Result<(), rusqlite::Error> { + let vosk_speech = self.sqlite.select_vosk(self.note_id); + return vosk_speech.and_then(|speech| { + let mut reader = hound::WavReader::open(speech.wav).unwrap(); + + let spec = reader.spec(); + let mut data = + Vec::with_capacity((spec.channels as usize) * (reader.duration() as usize)); + match (spec.bits_per_sample, spec.sample_format) { + (16, SampleFormat::Int) => { + for sample in reader.samples::() { + data.push((sample.unwrap() as f32) / (0x7fffi32 as f32)); + } + } + (24, SampleFormat::Int) => { + for sample in reader.samples::() { + let val = (sample.unwrap() as f32) / (0x00ff_ffffi32 as f32); + data.push(val); + } + } + (32, SampleFormat::Int) => { + for sample in reader.samples::() { + data.push((sample.unwrap() as f32) / (0x7fff_ffffi32 as f32)); + } + } + (32, SampleFormat::Float) => { + for sample in reader.samples::() { + data.push(sample.unwrap()); + } + } + _ => panic!( + "Tried to read file but there was a problem: {:?}", + hound::Error::Unsupported + ), + } + let data = if spec.channels != 1 { + whisper_rs::convert_stereo_to_mono_audio(&data).unwrap() + } else { + data + }; + let audio_data = convert( + spec.sample_rate, + 16000, + 1, + ConverterType::SincBestQuality, + &data, + ) + .unwrap(); + + let mut state = self.ctx.create_state().expect("failed to create state"); + let result = state.full( + Transcriber::build_params( + "ja".to_string(), + "large".to_string(), + ), + &audio_data[..], + ); + if result.is_ok() { + let num_segments = state + .full_n_segments() + .expect("failed to get number of segments"); + let mut converted: Vec = vec!["".to_string()]; + for i in 0..num_segments { + let segment = state.full_get_segment_text(i); + if segment.is_ok() { + converted.push(segment.unwrap().to_string()); + }; + } + + let result_on_whisper = converted.join(""); + let sources: Vec = result_on_whisper.lines().map(String::from).collect(); + let res: Vec<(String, Option)> = self + .translator + .translate_batch( + &sources, + &TranslationOptions { + beam_size: 5, + ..Default::default() + }, + None, + ) + .unwrap(); + let mut translated: Vec = vec!["".to_string()]; + for (r, _) in res { + translated.push(r); + } + + let updated = self + .sqlite + .update_model_vosk_to_whisper(speech.id, translated.join("")); + + let updated = updated.unwrap(); + if updated.content != "" { + self.app_handle + .clone() + .emit_all("finalTextConverted", updated) + .unwrap(); + } + } else { + println!("whisper is temporally failed, so skipping...") + } + + Ok(()) + }); + } +} + +pub static SINGLETON_INSTANCE: Mutex> = Mutex::new(None); + +pub fn initialize_translation_en(app_handle: AppHandle, note_id: u64) { + let mut singleton = SINGLETON_INSTANCE.lock().unwrap(); + if singleton.is_none() { + *singleton = Some(TranslationEn::new(app_handle, note_id)); + } +} + +pub fn drop_translation_en() { + let mut singleton = SINGLETON_INSTANCE.lock().unwrap(); + *singleton = None; +} diff --git a/src-tauri/src/module/translation_ja.rs b/src-tauri/src/module/translation_ja.rs index aebaa57..15b0b35 100644 --- a/src-tauri/src/module/translation_ja.rs +++ b/src-tauri/src/module/translation_ja.rs @@ -1,7 +1,7 @@ use super::{sqlite::Sqlite, transcriber::Transcriber}; use crossbeam_channel::Receiver; -use ct2rs::{config::Config, sentencepiece::Tokenizer, TranslationOptions, Translator}; +use ct2rs::{tokenizers::auto::Tokenizer, Config, TranslationOptions, Translator}; use hound::SampleFormat; use samplerate_rs::{convert, ConverterType}; use std::sync::Mutex; @@ -36,7 +36,6 @@ impl TranslationJa { ctx: Transcriber::build(app_handle_clone, "large-translate-to-en".to_string()), translator: Translator::new( &model_path, - Tokenizer::new(&model_path).unwrap(), &Config::default(), ) .unwrap(), @@ -154,6 +153,7 @@ impl TranslationJa { beam_size: 5, ..Default::default() }, + None ) .unwrap(); let mut translated: Vec = vec!["".to_string()]; diff --git a/src-tauri/src/module/translation_ja_high.rs b/src-tauri/src/module/translation_ja_high.rs index 1f13d4d..fe112c1 100644 --- a/src-tauri/src/module/translation_ja_high.rs +++ b/src-tauri/src/module/translation_ja_high.rs @@ -3,12 +3,15 @@ use super::{sqlite::Sqlite, transcriber::Transcriber}; use crossbeam_channel::Receiver; use hound::SampleFormat; use mistralrs::{ - Constraint, DefaultSchedulerMethod, Device, DeviceMapMetadata, GGUFLoaderBuilder, - GGUFSpecificConfig, MistralRs, MistralRsBuilder, ModelDType, NormalRequest, Request, - RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, TokenSource, + Constraint, DefaultSchedulerMethod, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, + ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalRequest, NormalSpecificConfig, + Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, TokenSource, }; use samplerate_rs::{convert, ConverterType}; -use std::sync::{Arc, Mutex}; +use std::{ + path::PathBuf, + sync::{Arc, Mutex}, +}; use tauri::{AppHandle, Manager}; use tokio::sync::mpsc::channel; use whisper_rs::WhisperContext; @@ -30,24 +33,30 @@ impl TranslationJaHigh { let app_handle_clone = app_handle.clone(); let model_path = app_handle .path_resolver() - .resolve_resource(format!("resources/honyaku13b-q4-0")) + .resolve_resource(format!("resources/honyaku-13b")) .unwrap() .to_string_lossy() .to_string(); - let loader = GGUFLoaderBuilder::new( - Some(format!("{}/chat_templates_llama2.json", model_path)), - None, - model_path, - vec!["aixsatoshi-Honyaku-13b-Q4_0.gguf".to_string()], - // vec!["aixsatoshi-Honyaku-13b-IQ4_XS.gguf".to_string()], - GGUFSpecificConfig { + let loader = NormalLoaderBuilder::new( + NormalSpecificConfig { + use_flash_attn: false, prompt_batchsize: None, topology: None, + organization: Default::default(), + write_uqff: None, + from_uqff: Some(PathBuf::from(format!( + "{}/Honyaku-13b-q4_0.uqff", + model_path + ))), }, + None, + None, + Some(model_path), ) - .build(); - let pipeline = tokio::task::block_in_place(|| { - loader.load_model_from_hf( + .build(Some(NormalLoaderType::Llama)) + .unwrap(); + let pipeline = loader + .load_model_from_hf( None, TokenSource::None, &ModelDType::Auto, @@ -57,8 +66,7 @@ impl TranslationJaHigh { None, None, ) - }) - .unwrap(); + .unwrap(); TranslationJaHigh { app_handle, @@ -70,6 +78,7 @@ impl TranslationJaHigh { method: DefaultSchedulerMethod::Fixed(5.try_into().unwrap()), }, ) + .with_no_prefix_cache(true) .build(), speaker_language, note_id, @@ -177,14 +186,14 @@ impl TranslationJaHigh { let result_on_whisper = converted.join(""); let prompt = format!(": {} \n\n: ", result_on_whisper); - let (tx, mut rx) = channel(10_000); + let (tx, mut rx) = channel(1); let request = Request::Normal(NormalRequest { messages: RequestMessage::Completion { text: prompt, echo_prompt: false, best_of: 1, }, - sampling_params: SamplingParams::default(), + sampling_params: SamplingParams::deterministic(), response: tx, return_logprobs: false, is_streaming: false, diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index 2a64283..6af6636 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -8,7 +8,7 @@ }, "package": { "productName": "Lycoris", - "version": "0.9.21" + "version": "0.9.22" }, "tauri": { "allowlist": { @@ -60,9 +60,9 @@ "windows": [ { "fullscreen": false, - "width": 1024, + "width": 1180, "height": 768, - "minWidth": 1024, + "minWidth": 1180, "minHeight": 384, "resizable": true, "title": "Lycoris" diff --git a/src/components/Header.tsx b/src/components/Header.tsx index 1944202..e632ddd 100644 --- a/src/components/Header.tsx +++ b/src/components/Header.tsx @@ -1,4 +1,4 @@ -import { useSetRecoilState } from "recoil" +import { useRecoilValue, useSetRecoilState } from "recoil" import { getVersion } from '@tauri-apps/api/app'; import { featureState } from "../store/atoms/featureState" import { selectedNoteState } from "../store/atoms/selectedNoteState" @@ -6,6 +6,8 @@ import { AudioDevices } from "./molecules/AudioDevice" import { SpeakerLanguage } from "./molecules/SpeakerLanguage" import { TranscriptionAccuracy } from "./molecules/TranscriptionAccuracy" import { useEffect, useState } from "react"; +import { SmartVoice } from "./molecules/SmartVoice"; +import { modelStyleBertVits2DownloadedState } from "../store/atoms/modelStyleBertVits2DownloadedState"; const Header = (): JSX.Element => { const setFeature = useSetRecoilState(featureState) @@ -19,15 +21,20 @@ const Header = (): JSX.Element => { fetchVersion(); }, []); + const downloadedBaseModels = useRecoilValue(modelStyleBertVits2DownloadedState); + const is_base_downloaded = downloadedBaseModels.filter(m => m === "style-bert-vits2").length > 0 + return (
- Lycoris -

v{appVersion}

+
+ Lycoris +

v{appVersion}

+
- + @@ -36,7 +43,7 @@ const Header = (): JSX.Element => {
- +
@@ -44,14 +51,30 @@ const Header = (): JSX.Element => {
- +
-
- -
+ {is_base_downloaded ? + <> +
+ +
+
+ + + +
+
+ +
+ + : +
+ +
+ }
{ setFeature("settings"); setSelectedNote(null); }}> diff --git a/src/components/molecules/ActionSet.tsx b/src/components/molecules/ActionSet.tsx index fc0ab52..fea59bc 100644 --- a/src/components/molecules/ActionSet.tsx +++ b/src/components/molecules/ActionSet.tsx @@ -1,10 +1,12 @@ import { useRef, useState } from "react" +import { useRecoilState } from "recoil" +import { actionState } from "../../store/atoms/actionState" const ActionSet = (): JSX.Element => { const dropdownRef = useRef(null) - const [targetAction, setTargetAction] = useState("チャット") - const actions = ["チャット"] + const [targetAction, setTargetAction] = useRecoilState(actionState) + const actions = ["チャット", "発話サジェスト"] const [toggle, setToggle] = useState(false) const change = (actionName: string) => { diff --git a/src/components/molecules/AppWindow.tsx b/src/components/molecules/AppWindow.tsx index 54f4056..a275c31 100644 --- a/src/components/molecules/AppWindow.tsx +++ b/src/components/molecules/AppWindow.tsx @@ -5,6 +5,8 @@ import { AppWindowType } from "../../type/AppWindow.type" import { appWindowState } from "../../store/atoms/appWindowState" import { useHasPermissionScreenCapture } from "../../hooks/useHasPermissionScreenCapture" import { ScreenShotButton } from "./ScreenshotButton" +import { appSelectedState } from "../../store/atoms/appSelectedState" +import { appWindowsState } from "../../store/atoms/appWindowsState" const AppWindow = (): JSX.Element => { const [isDesktopAudioToggled, setIsDesktopAudioToggled] = useState(null) @@ -12,25 +14,27 @@ const AppWindow = (): JSX.Element => { const dropdownRef = useRef(null) - const [targetApp, setTargetApp] = useState(null) - const [apps, setApps] = useState([]) + const [targetApp, setTargetApp] = useRecoilState(appSelectedState) + const [apps, setApps] = useState([] as string[]) const [toggle, setToggle] = useState(false) useEffect(() => { invoke('list_apps_command').then(apps => setApps(apps as string[])) }, [toggle]) const [targetWindow, setTargetWindow] = useRecoilState(appWindowState) - const [appWindows, setAppWindows] = useState([]) - const change = (e: ChangeEvent) => { + const [appWindows, setAppWindows] = useRecoilState(appWindowsState) + const change = (appWindowId: number) => { dropdownRef.current?.focus(); - if (e.target.checked) { - const appWindowId = e.target.value - const targetAppWindow = appWindows.filter(({ id }) => id === parseInt(appWindowId))[0] - setTargetWindow(targetAppWindow) + const targetAppWindows = appWindows.filter(({ id }) => id === appWindowId) + if (targetAppWindows.length > 1) { + setTargetWindow(null) + } else { + setTargetWindow(targetAppWindows[0]) } } const click = (appName: string) => { setTargetApp(appName) + setTargetWindow(null) invoke('list_app_windows_command', { appName }) .then(windows => setAppWindows(windows as AppWindowType[])) } @@ -42,48 +46,50 @@ const AppWindow = (): JSX.Element => { } } - return (<> -
- setIsDesktopAudioToggled(!isDesktopAudioToggled) - }> -
+
+ + ) } diff --git a/src/components/molecules/ModelDownloadFugumtEnJaButton.tsx b/src/components/molecules/ModelDownloadFugumtEnJaButton.tsx new file mode 100644 index 0000000..81261e4 --- /dev/null +++ b/src/components/molecules/ModelDownloadFugumtEnJaButton.tsx @@ -0,0 +1,30 @@ +import { invoke } from '@tauri-apps/api/tauri' +import { useRecoilState, useRecoilValue } from 'recoil' +import { modelFugumtEnJaDownloadingState } from '../../store/atoms/modelFugumtEnJaDownloadingState' +import { modelFugumtEnJaDownloadedState } from '../../store/atoms/modelFugumtEnJaDownloadedState' +import { modelWhisperDownloadedState } from '../../store/atoms/modelWhisperDownloadedState' + +const ModelDownloadFugumtEnJaButton = (): JSX.Element => { + const modelType = "fugumt-en-ja" + const downloadedModels = useRecoilValue(modelFugumtEnJaDownloadedState) + const downloadedBaseModels = useRecoilValue(modelWhisperDownloadedState) + const [downloadingModels, setDownloadingModels] = useRecoilState(modelFugumtEnJaDownloadingState) + const click = () => { + setDownloadingModels([...downloadingModels, modelType]) + invoke('download_fugumt_enja_model_command') + } + const is_downloaded = downloadedModels.filter(m => m === modelType).length > 0 + const is_downloading = downloadingModels.filter(m => m === modelType).length > 0 + const is_base_downloaded = downloadedBaseModels.filter(m => m === "large").length > 0 + + return ( + + ) +} + +export { ModelDownloadFugumtEnJaButton } \ No newline at end of file diff --git a/src/components/molecules/ModelDownloadFugumtProgress.tsx b/src/components/molecules/ModelDownloadFugumtEnJaProgress.tsx similarity index 73% rename from src/components/molecules/ModelDownloadFugumtProgress.tsx rename to src/components/molecules/ModelDownloadFugumtEnJaProgress.tsx index 33ca420..86c7742 100644 --- a/src/components/molecules/ModelDownloadFugumtProgress.tsx +++ b/src/components/molecules/ModelDownloadFugumtEnJaProgress.tsx @@ -1,21 +1,21 @@ import { useRecoilState, useSetRecoilState } from 'recoil' -import { modelFugumtDownloadingState } from '../../store/atoms/modelFugumtDownloadingState' +import { modelFugumtEnJaDownloadingState } from '../../store/atoms/modelFugumtEnJaDownloadingState' import { listen } from '@tauri-apps/api/event' import { useEffect, useState } from 'react' import { ProgressType } from '../../type/progress.type' -import { modelFugumtDownloadedState } from '../../store/atoms/modelFugumtDownloadedState' +import { modelFugumtEnJaDownloadedState } from '../../store/atoms/modelFugumtEnJaDownloadedState' -const ModelDownloadFugumtProgress = (): JSX.Element => { +const ModelDownloadFugumtEnJaProgress = (): JSX.Element => { const modelType = "fugumt-en-ja" - const setDownloadedModel = useSetRecoilState(modelFugumtDownloadedState) - const [downloadingModels, setDownloadingModels] = useRecoilState(modelFugumtDownloadingState) + const setDownloadedModel = useSetRecoilState(modelFugumtEnJaDownloadedState) + const [downloadingModels, setDownloadingModels] = useRecoilState(modelFugumtEnJaDownloadingState) const [progress, setProgress] = useState({ model_type: modelType, rate: 0, is_progress: false }) useEffect(() => { - const unlisten = listen('downloadFugumtProgress', event => { + const unlisten = listen('downloadFugumtEnJaProgress', event => { const p = event.payload as ProgressType if (p.model_type === modelType) { setProgress(p) @@ -39,4 +39,4 @@ const ModelDownloadFugumtProgress = (): JSX.Element => { return (<>) } -export { ModelDownloadFugumtProgress } \ No newline at end of file +export { ModelDownloadFugumtEnJaProgress } \ No newline at end of file diff --git a/src/components/molecules/ModelDownloadFugumtJaEnButton.tsx b/src/components/molecules/ModelDownloadFugumtJaEnButton.tsx new file mode 100644 index 0000000..a1a7f6f --- /dev/null +++ b/src/components/molecules/ModelDownloadFugumtJaEnButton.tsx @@ -0,0 +1,30 @@ +import { invoke } from '@tauri-apps/api/tauri' +import { useRecoilState, useRecoilValue } from 'recoil' +import { modelFugumtJaEnDownloadingState } from '../../store/atoms/modelFugumtJaEnDownloadingState' +import { modelFugumtJaEnDownloadedState } from '../../store/atoms/modelFugumtJaEnDownloadedState' +import { modelWhisperDownloadedState } from '../../store/atoms/modelWhisperDownloadedState' + +const ModelDownloadFugumtJaEnButton = (): JSX.Element => { + const modelType = "fugumt-ja-en" + const downloadedModels = useRecoilValue(modelFugumtJaEnDownloadedState) + const downloadedBaseModels = useRecoilValue(modelWhisperDownloadedState) + const [downloadingModels, setDownloadingModels] = useRecoilState(modelFugumtJaEnDownloadingState) + const click = () => { + setDownloadingModels([...downloadingModels, modelType]) + invoke('download_fugumt_jaen_model_command') + } + const is_downloaded = downloadedModels.filter(m => m === modelType).length > 0 + const is_downloading = downloadingModels.filter(m => m === modelType).length > 0 + const is_base_downloaded = downloadedBaseModels.filter(m => m === "large").length > 0 + + return ( + + ) +} + +export { ModelDownloadFugumtJaEnButton } \ No newline at end of file diff --git a/src/components/molecules/ModelDownloadFugumtJaEnProgress.tsx b/src/components/molecules/ModelDownloadFugumtJaEnProgress.tsx new file mode 100644 index 0000000..8984eb7 --- /dev/null +++ b/src/components/molecules/ModelDownloadFugumtJaEnProgress.tsx @@ -0,0 +1,42 @@ +import { useRecoilState, useSetRecoilState } from 'recoil' +import { modelFugumtJaEnDownloadingState } from '../../store/atoms/modelFugumtJaEnDownloadingState' +import { listen } from '@tauri-apps/api/event' +import { useEffect, useState } from 'react' +import { ProgressType } from '../../type/progress.type' +import { modelFugumtJaEnDownloadedState } from '../../store/atoms/modelFugumtJaEnDownloadedState' + +const ModelDownloadFugumtJaEnProgress = (): JSX.Element => { + const modelType = "fugumt-ja-en" + const setDownloadedModel = useSetRecoilState(modelFugumtJaEnDownloadedState) + const [downloadingModels, setDownloadingModels] = useRecoilState(modelFugumtJaEnDownloadingState) + const [progress, setProgress] = useState({ + model_type: modelType, + rate: 0, + is_progress: false + }) + useEffect(() => { + const unlisten = listen('downloadFugumtJaEnProgress', event => { + const p = event.payload as ProgressType + if (p.model_type === modelType) { + setProgress(p) + if (!p.is_progress) { + setDownloadingModels(prev => prev.filter(m => m !== modelType)) + setDownloadedModel(prev => [...prev, modelType]) + } + } + }) + return () => { + unlisten.then(f => f()); + } + }, []) + if (downloadingModels.filter(m => m === modelType).length > 0) { + return ( +
+
{progress.rate === 100 ? "解凍中" : `${progress.rate}%`}
+
+ ) + } + return (<>) +} + +export { ModelDownloadFugumtJaEnProgress } \ No newline at end of file diff --git a/src/components/molecules/ModelDownloadHonyaku13BButton.tsx b/src/components/molecules/ModelDownloadHonyaku13BButton.tsx index ddce6d1..89dbcd7 100644 --- a/src/components/molecules/ModelDownloadHonyaku13BButton.tsx +++ b/src/components/molecules/ModelDownloadHonyaku13BButton.tsx @@ -2,10 +2,12 @@ import { invoke } from '@tauri-apps/api/tauri' import { useRecoilState, useRecoilValue } from 'recoil' import { modelHonyaku13BDownloadingState } from '../../store/atoms/modelHonyaku13BDownloadingState' import { modelHonyaku13BDownloadedState } from '../../store/atoms/modelHonyaku13BDownloadedState' +import { modelWhisperDownloadedState } from '../../store/atoms/modelWhisperDownloadedState' const ModelDownloadHonyaku13BButton = (): JSX.Element => { - const modelType = "honyaku13b-q4-0" + const modelType = "honyaku-13b" const downloadedModels = useRecoilValue(modelHonyaku13BDownloadedState) + const downloadedBaseModels = useRecoilValue(modelWhisperDownloadedState) const [downloadingModels, setDownloadingModels] = useRecoilState(modelHonyaku13BDownloadingState) const click = () => { setDownloadingModels([...downloadingModels, modelType]) @@ -13,13 +15,14 @@ const ModelDownloadHonyaku13BButton = (): JSX.Element => { } const is_downloaded = downloadedModels.filter(m => m === modelType).length > 0 const is_downloading = downloadingModels.filter(m => m === modelType).length > 0 + const is_base_downloaded = downloadedBaseModels.filter(m => m === "large").length > 0 return ( - ) } diff --git a/src/components/molecules/ModelDownloadHonyaku13BProgress.tsx b/src/components/molecules/ModelDownloadHonyaku13BProgress.tsx index 54f6523..c1f2b98 100644 --- a/src/components/molecules/ModelDownloadHonyaku13BProgress.tsx +++ b/src/components/molecules/ModelDownloadHonyaku13BProgress.tsx @@ -6,7 +6,7 @@ import { ProgressType } from '../../type/progress.type' import { modelHonyaku13BDownloadedState } from '../../store/atoms/modelHonyaku13BDownloadedState' const ModelDownloadHonyaku13BProgress = (): JSX.Element => { - const modelType = "honyaku13b-q4-0" + const modelType = "honyaku-13b" const setDownloadedModel = useSetRecoilState(modelHonyaku13BDownloadedState) const [downloadingModels, setDownloadingModels] = useRecoilState(modelHonyaku13BDownloadingState) const [progress, setProgress] = useState({ diff --git a/src/components/molecules/ModelDownloadFugumtButton.tsx b/src/components/molecules/ModelDownloadStyleBertVits2Button.tsx similarity index 65% rename from src/components/molecules/ModelDownloadFugumtButton.tsx rename to src/components/molecules/ModelDownloadStyleBertVits2Button.tsx index 50bc0b2..17ea18a 100644 --- a/src/components/molecules/ModelDownloadFugumtButton.tsx +++ b/src/components/molecules/ModelDownloadStyleBertVits2Button.tsx @@ -1,15 +1,15 @@ import { invoke } from '@tauri-apps/api/tauri' import { useRecoilState, useRecoilValue } from 'recoil' -import { modelFugumtDownloadingState } from '../../store/atoms/modelFugumtDownloadingState' -import { modelFugumtDownloadedState } from '../../store/atoms/modelFugumtDownloadedState' +import { modelStyleBertVits2DownloadingState } from '../../store/atoms/modelStyleBertVits2DownloadingState' +import { modelStyleBertVits2DownloadedState } from '../../store/atoms/modelStyleBertVits2DownloadedState' -const ModelDownloadFugumtButton = (): JSX.Element => { - const modelType = "fugumt-en-ja" - const downloadedModels = useRecoilValue(modelFugumtDownloadedState) - const [downloadingModels, setDownloadingModels] = useRecoilState(modelFugumtDownloadingState) +const ModelDownloadStyleBertVits2Button = (): JSX.Element => { + const modelType = "style-bert-vits2" + const downloadedModels = useRecoilValue(modelStyleBertVits2DownloadedState) + const [downloadingModels, setDownloadingModels] = useRecoilState(modelStyleBertVits2DownloadingState) const click = () => { setDownloadingModels([...downloadingModels, modelType]) - invoke('download_fugumt_model_command') + invoke('download_sbv2_command') } const is_downloaded = downloadedModels.filter(m => m === modelType).length > 0 const is_downloading = downloadingModels.filter(m => m === modelType).length > 0 @@ -24,4 +24,4 @@ const ModelDownloadFugumtButton = (): JSX.Element => { ) } -export { ModelDownloadFugumtButton } \ No newline at end of file +export { ModelDownloadStyleBertVits2Button } \ No newline at end of file diff --git a/src/components/molecules/ModelDownloadStyleBertVits2Progress.tsx b/src/components/molecules/ModelDownloadStyleBertVits2Progress.tsx new file mode 100644 index 0000000..3e43361 --- /dev/null +++ b/src/components/molecules/ModelDownloadStyleBertVits2Progress.tsx @@ -0,0 +1,42 @@ +import { useRecoilState, useSetRecoilState } from 'recoil' +import { modelStyleBertVits2DownloadingState } from '../../store/atoms/modelStyleBertVits2DownloadingState' +import { listen } from '@tauri-apps/api/event' +import { useEffect, useState } from 'react' +import { ProgressType } from '../../type/progress.type' +import { modelStyleBertVits2DownloadedState } from '../../store/atoms/modelStyleBertVits2DownloadedState' + +const ModelDownloadStyleBertVits2Progress = (): JSX.Element => { + const modelType = "style-bert-vits2" + const setDownloadedModel = useSetRecoilState(modelStyleBertVits2DownloadedState) + const [downloadingModels, setDownloadingModels] = useRecoilState(modelStyleBertVits2DownloadingState) + const [progress, setProgress] = useState({ + model_type: modelType, + rate: 0, + is_progress: false + }) + useEffect(() => { + const unlisten = listen('downloadStyleBertVits2Progress', event => { + const p = event.payload as ProgressType + if (p.model_type === modelType) { + setProgress(p) + if (!p.is_progress) { + setDownloadingModels(prev => prev.filter(m => m !== modelType)) + setDownloadedModel(prev => [...prev, modelType]) + } + } + }) + return () => { + unlisten.then(f => f()); + } + }, []) + if (downloadingModels.filter(m => m === modelType).length > 0) { + return ( +
+
{progress.rate === 100 ? "解凍中" : `${progress.rate}%`}
+
+ ) + } + return (<>) +} + +export { ModelDownloadStyleBertVits2Progress } \ No newline at end of file diff --git a/src/components/molecules/ModelDownloadStyleBertVits2VoiceButton.tsx b/src/components/molecules/ModelDownloadStyleBertVits2VoiceButton.tsx new file mode 100644 index 0000000..9705460 --- /dev/null +++ b/src/components/molecules/ModelDownloadStyleBertVits2VoiceButton.tsx @@ -0,0 +1,34 @@ +import { invoke } from '@tauri-apps/api/tauri' +import { useRecoilState, useRecoilValue } from 'recoil' +import { modelStyleBertVits2VoiceDownloadingState } from '../../store/atoms/modelStyleBertVits2VoiceDownloadingState' +import { modelStyleBertVits2VoiceDownloadedState } from '../../store/atoms/modelStyleBertVits2VoiceDownloadedState' +import { modelStyleBertVits2DownloadedState } from '../../store/atoms/modelStyleBertVits2DownloadedState' + +type Props = { + modelType: string +} + +const ModelDownloadStyleBertVits2VoiceButton = (props: Props): JSX.Element => { + const { modelType } = props + const downloadedModels = useRecoilValue(modelStyleBertVits2VoiceDownloadedState); + const downloadedBaseModels = useRecoilValue(modelStyleBertVits2DownloadedState); + const [downloadingModels, setDownloadingModels] = useRecoilState(modelStyleBertVits2VoiceDownloadingState) + const click = () => { + setDownloadingModels([...downloadingModels, modelType]) + invoke('download_sbv2_model_command', { model: modelType }) + } + const is_downloaded = downloadedModels.filter(m => m === modelType).length > 0 + const is_downloading = downloadingModels.filter(m => m === modelType).length > 0 + const is_base_downloaded = downloadedBaseModels.filter(m => m === "style-bert-vits2").length > 0 + + return ( + + ) +} + +export { ModelDownloadStyleBertVits2VoiceButton } \ No newline at end of file diff --git a/src/components/molecules/ModelDownloadStyleBertVits2VoiceProgress.tsx b/src/components/molecules/ModelDownloadStyleBertVits2VoiceProgress.tsx new file mode 100644 index 0000000..ac1615c --- /dev/null +++ b/src/components/molecules/ModelDownloadStyleBertVits2VoiceProgress.tsx @@ -0,0 +1,46 @@ +import { useRecoilState, useSetRecoilState } from 'recoil' +import { listen } from '@tauri-apps/api/event' +import { useEffect, useState } from 'react' +import { ProgressType } from '../../type/progress.type' +import { modelStyleBertVits2VoiceDownloadingState } from '../../store/atoms/modelStyleBertVits2VoiceDownloadingState' +import { modelStyleBertVits2VoiceDownloadedState } from '../../store/atoms/modelStyleBertVits2VoiceDownloadedState' + +type Props = { + modelType: string +} + +const ModelDownloadStyleBertVits2VoiceProgress = (props: Props): JSX.Element => { + const { modelType } = props + const setDownloadedModel = useSetRecoilState(modelStyleBertVits2VoiceDownloadedState) + const [downloadingModels, setDownloadingModels] = useRecoilState(modelStyleBertVits2VoiceDownloadingState) + const [progress, setProgress] = useState({ + model_type: modelType, + rate: 0, + is_progress: false + }) + useEffect(() => { + const unlisten = listen('downloadStyleBertVits2VoiceProgress', event => { + const p = event.payload as ProgressType + if (p.model_type === modelType) { + setProgress(p) + if (!p.is_progress) { + setDownloadingModels(prev => prev.filter(m => m !== modelType)) + setDownloadedModel(prev => [...prev, modelType]) + } + } + }) + return () => { + unlisten.then(f => f()); + } + }, []) + if (downloadingModels.filter(m => m === modelType).length > 0) { + return ( +
+
{`${progress.rate}%`}
+
+ ) + } + return (<>) +} + +export { ModelDownloadStyleBertVits2VoiceProgress } \ No newline at end of file diff --git a/src/components/molecules/MyMarkdown.tsx b/src/components/molecules/MyMarkdown.tsx index c6bf412..329c2ae 100644 --- a/src/components/molecules/MyMarkdown.tsx +++ b/src/components/molecules/MyMarkdown.tsx @@ -25,6 +25,9 @@ const MyMarkdown = (props: MyMarkdownProps) => { const [contents, setContents] = useState([]); + const [isTableSelected, setIsTableSelected] = useState(false); + const [tableId, setTableId] = useState(0); + const [isTextSelected, setIsTextSelected] = useState(false); const [textSelected, setTextSelected] = useState(""); const handleMouseDown = (e: MouseEvent) => { @@ -49,18 +52,34 @@ const MyMarkdown = (props: MyMarkdownProps) => { clipboard.writeText(contents[elementId]); } const handleImage = async (type: "copy" | "download") => { - const target = rootRef.current?.querySelectorAll("pre code")[elementId] as HTMLElement; - const canvas = await html2canvas(target, - { - backgroundColor: null, - onclone: (_, element) => { - element.style.setProperty("overflow-x", "unset"); - element.style.setProperty("width", "fit-content"); - if (!target.className.includes("mermaid")) { - element.style.backgroundColor = "#1a2638"; - } - } - }); + const target = (() => { + if (isTableSelected) return rootRef.current?.querySelectorAll("table")[tableId] as HTMLElement; + return rootRef.current?.querySelectorAll("pre code")[elementId] as HTMLElement; + })(); + const canvas = await (async () => { + if (isTableSelected) { + return await html2canvas(target, + { + backgroundColor: null, + onclone: (_, element) => { + element.style.setProperty("overflow-x", "unset"); + element.style.setProperty("width", "fit-content"); + } + }); + } else { + return await html2canvas(target, + { + backgroundColor: null, + onclone: (_, element) => { + element.style.setProperty("overflow-x", "unset"); + if (!target.className.includes("mermaid")) { + element.style.backgroundColor = "#1a2638"; + element.style.width = "unset"; + } + } + }); + } + })(); if (type === "download") { const blob = await new Promise((resolve) => { canvas.toBlob((blob) => { @@ -90,11 +109,13 @@ const MyMarkdown = (props: MyMarkdownProps) => { block.classList.add("hover:border-base-300", "border-2", "border-transparent", "rounded-lg", "cursor-pointer"); } else { hljs.highlightBlock(block as HTMLElement); + block.classList.add("cursor-pointer", "w-full"); } const handleContextMenu = (e: MouseEvent) => { e.preventDefault(); setAnchorPoint({ x: e.clientX, y: e.clientY }); + setIsTableSelected(false); setElementId(index); setOpen(true); }; @@ -103,6 +124,21 @@ const MyMarkdown = (props: MyMarkdownProps) => { listeners.set(block, ['contextmenu', handleContextMenu]); }); + rootRef.current?.querySelectorAll('table').forEach(async (block, index) => { + block.classList.add("hover:border-base-300", "border-2", "border-transparent", "rounded-lg", "cursor-pointer", "!w-fit"); + + const handleContextMenu = (e: MouseEvent) => { + e.preventDefault(); + setAnchorPoint({ x: e.clientX, y: e.clientY }); + setIsTableSelected(true); + setTableId(index); + setOpen(true); + }; + + (block as HTMLElement).addEventListener('contextmenu', handleContextMenu); + listeners.set(block, ['contextmenu', handleContextMenu]); + }); + return () => { listeners.forEach(([event, listener], block) => { (block as HTMLElement).removeEventListener(event, listener); @@ -135,20 +171,31 @@ const MyMarkdown = (props: MyMarkdownProps) => {

コピー

: - <> - - -

全体をコピー

-
- handleImage("copy")}> - -

画像としてコピー

-
- handleImage("download")}> - -

画像としてダウンロード

-
- } + isTableSelected ? + <> + handleImage("copy")}> + +

画像としてコピー

+
+ handleImage("download")}> + +

画像としてダウンロード

+
+ : + <> + + +

全体をコピー

+
+ handleImage("copy")}> + +

画像としてコピー

+
+ handleImage("download")}> + +

画像としてダウンロード

+
+ }
) diff --git a/src/components/molecules/Screenshot.tsx b/src/components/molecules/Screenshot.tsx index 46d42c0..56c86f1 100644 --- a/src/components/molecules/Screenshot.tsx +++ b/src/components/molecules/Screenshot.tsx @@ -14,7 +14,12 @@ const Screenshot = (props: ScreenshotProps): JSX.Element => {
{date}
- screenshot + screenshot
diff --git a/src/components/molecules/SettingFCfunctionCall.tsx b/src/components/molecules/SettingFCfunctionCall.tsx index 36ff534..118fd5d 100644 --- a/src/components/molecules/SettingFCfunctionCall.tsx +++ b/src/components/molecules/SettingFCfunctionCall.tsx @@ -14,7 +14,7 @@ const SettingFCfunctionCall = (): JSX.Element => {

Function Calling
(function_call)

-

AIからの返答に利用する関数を選択

+

アシスタントからの返答に利用する関数を選択

無指定では、必要なときのみ関数が実行されます

必ず実行する場合は、関数名を指定してください

diff --git a/src/components/molecules/SettingFCfunctions.tsx b/src/components/molecules/SettingFCfunctions.tsx index 6bb8ecf..287b668 100644 --- a/src/components/molecules/SettingFCfunctions.tsx +++ b/src/components/molecules/SettingFCfunctions.tsx @@ -14,7 +14,7 @@ const SettingFCfunctions = (): JSX.Element => {

Function Calling
(functions)

-

AIからの返答に利用する関数一覧

+

アシスタントからの返答に利用する関数一覧