From 360c7bfebe1b47c73b865bf82a56571cbcae8d46 Mon Sep 17 00:00:00 2001
From: Clark McCauley
Date: Sat, 5 Aug 2023 10:02:00 -0600
Subject: [PATCH] Cancellation (#20)

* wip

* Blocking cancellation
---
 src-tauri/Cargo.lock          | 74 ++++++++++++++++++++++++++++++-----
 src-tauri/Cargo.toml          |  1 +
 src-tauri/src/cancellation.rs | 56 ++++++++++++++++++++++++++
 src-tauri/src/main.rs         | 56 ++++++++++++++++----------
 src-tauri/src/models.rs       | 15 ++-----
 src/Chat.jsx                  | 20 ++++++++--
 src/Sidebar.jsx               |  4 +-
 src/api.js                    |  4 ++
 src/state/messagesSlice.js    |  5 +++
 src/utilities.js              | 14 +++++++
 10 files changed, 202 insertions(+), 47 deletions(-)
 create mode 100644 src-tauri/src/cancellation.rs

diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock
index 7e40209..1524a83 100644
--- a/src-tauri/Cargo.lock
+++ b/src-tauri/Cargo.lock
@@ -527,6 +527,7 @@ dependencies = [
  "anyhow",
  "bytesize",
  "cocoa",
+ "flume",
  "futures-util",
  "home",
  "html2text",
@@ -1241,6 +1242,19 @@ dependencies = [
  "miniz_oxide",
 ]
 
+[[package]]
+name = "flume"
+version = "0.10.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+ "nanorand",
+ "pin-project",
+ "spin",
+]
+
 [[package]]
 name = "fnv"
 version = "1.0.7"
@@ -1514,14 +1528,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
 dependencies = [
  "cfg-if",
+ "js-sys",
  "libc",
  "wasi 0.11.0+wasi-snapshot-preview1",
+ "wasm-bindgen",
 ]
 
 [[package]]
 name = "ggml"
 version = "0.2.0-dev"
-source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5"
+source = "git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc"
 dependencies = [
  "ggml-sys",
  "memmap2",
@@ -1531,7 +1547,7 @@
 [[package]]
 name = "ggml-sys"
 version = "0.2.0-dev"
-source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5"
+source = "git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc"
 dependencies = [
  "cc",
 ]
@@ -2212,7 +2228,7 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519"
 [[package]]
 name = "llm"
 version = "0.2.0-dev"
-source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5"
+source = "git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc"
 dependencies = [
  "llm-base",
  "llm-bloom",
@@ -2228,7 +2244,7 @@ dependencies = [
 [[package]]
 name = "llm-base"
 version = "0.2.0-dev"
-source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5"
+source = "git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc"
 dependencies = [
  "bytemuck",
  "ggml",
@@ -2247,7 +2263,7 @@ dependencies = [
 [[package]]
 name = "llm-bloom"
 version = "0.2.0-dev"
-source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5"
+source = "git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc"
 dependencies = [
  "llm-base",
 ]
@@ -2255,7 +2271,7 @@ dependencies = [
 [[package]]
 name = "llm-gpt2"
 version = "0.2.0-dev"
-source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5"
"git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc" dependencies = [ "bytemuck", "llm-base", @@ -2264,7 +2280,7 @@ dependencies = [ [[package]] name = "llm-gptj" version = "0.2.0-dev" -source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5" +source = "git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc" dependencies = [ "llm-base", ] @@ -2272,7 +2288,7 @@ dependencies = [ [[package]] name = "llm-gptneox" version = "0.2.0-dev" -source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5" +source = "git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc" dependencies = [ "llm-base", ] @@ -2280,7 +2296,7 @@ dependencies = [ [[package]] name = "llm-llama" version = "0.2.0-dev" -source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5" +source = "git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc" dependencies = [ "llm-base", "tracing", @@ -2289,7 +2305,7 @@ dependencies = [ [[package]] name = "llm-mpt" version = "0.2.0-dev" -source = "git+https://github.com/rustformers/llm?branch=main#3b114b00c8338e78bfd99992911812471f1cdab5" +source = "git+https://github.com/rustformers/llm?branch=main#39eb341aeda6a3ff0240421e54df2707ae8743fc" dependencies = [ "llm-base", ] @@ -2515,6 +2531,15 @@ dependencies = [ "syn 2.0.26", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom 0.2.10", +] + [[package]] name = "native-tls" version = "0.2.11" @@ -3078,6 +3103,26 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "030ad2bc4db10a8944cb0d837f158bdfec4d4a4873ab701a95046770d11f8842" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec2e072ecce94ec471b13398d5402c188e76ac03cf74dd1a975161b23a3f6d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.26", +] + [[package]] name = "pin-project-lite" version = "0.2.10" @@ -3929,6 +3974,15 @@ dependencies = [ "system-deps 5.0.0", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + [[package]] name = "spm_precompiled" version = "0.1.4" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 326a903..8376870 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -19,6 +19,7 @@ serde_json = "1.0" home = "0.5.5" rand = "0.8.5" lazy_static = "1.4.0" +flume = "0.10.14" anyhow = "1.0.71" reqwest = { version = "0.11.17", features = ["stream"] } futures-util = "0.3.28" diff --git a/src-tauri/src/cancellation.rs b/src-tauri/src/cancellation.rs new file mode 100644 index 0000000..705db67 --- /dev/null +++ b/src-tauri/src/cancellation.rs @@ -0,0 +1,56 @@ +use flume::{bounded, Receiver, SendError, Sender}; +use std::convert::Infallible; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Mutex; +use tracing::info; + +pub struct Canceller { + cancelled: 
+pub struct Canceller {
+    cancelled: AtomicBool,
+    rx: Mutex<Receiver<()>>,
+    tx: Sender<()>,
+}
+
+impl Canceller {
+    #[tracing::instrument(skip(self))]
+    pub async fn cancel(&self) -> Result<(), SendError<()>> {
+        info!("cancelling inference");
+        self.cancelled.store(true, Ordering::Release);
+        self.tx.send_async(()).await
+    }
+
+    pub fn is_cancelled(&self) -> bool {
+        self.cancelled.load(Ordering::Acquire)
+    }
+
+    pub fn reset(&self) {
+        self.cancelled.store(false, Ordering::Release);
+    }
+
+    #[tracing::instrument(skip(self))]
+    pub fn inference_feedback(&self) -> Result<llm::InferenceFeedback, Infallible> {
+        // When a cancellation occurs, the sender will block until it is received at least
+        // once. We want to check and see if that message has been sent, and if so we'll cancel.
+        let cancelled = if let Ok(rx) = self.rx.try_lock() {
+            rx.try_recv().is_ok()
+        } else {
+            false
+        };
+        if cancelled || self.is_cancelled() {
+            info!("sending halt");
+            Ok(llm::InferenceFeedback::Halt)
+        } else {
+            Ok(llm::InferenceFeedback::Continue)
+        }
+    }
+}
+
+impl Default for Canceller {
+    fn default() -> Self {
+        let (tx, rx) = bounded(0);
+        Self {
+            cancelled: Default::default(),
+            rx: Mutex::new(rx),
+            tx,
+        }
+    }
+}
diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs
index 5da07c6..a684516 100644
--- a/src-tauri/src/main.rs
+++ b/src-tauri/src/main.rs
@@ -11,22 +11,22 @@
 mod config;
 mod events;
 mod models;
+mod cancellation;
 mod context_file;
 mod prompt;
 #[cfg(target_os = "macos")]
 mod titlebar;
 
-#[cfg(target_os = "macos")]
-use crate::titlebar::WindowExt;
-
+use crate::cancellation::Canceller;
 use crate::config::get_logs_dir;
 use crate::events::Event;
 use crate::models::{get_local_model, Architecture, Model, ModelManager};
 use crate::prompt::Template;
+#[cfg(target_os = "macos")]
+use crate::titlebar::WindowExt;
 use bytesize::ByteSize;
 use llm::{InferenceResponse, LoadProgress};
 use serde::Serialize;
-use std::convert::Infallible;
 use std::fs;
 use std::fs::create_dir_all;
 use std::path::PathBuf;
@@ -55,10 +55,16 @@ fn get_prompt_templates() -> Vec