Cancellation (#20)
* wip

* Blocking cancellation
clarkmcc authored Aug 5, 2023
1 parent b2c70f5 commit 360c7bf
Showing 10 changed files with 202 additions and 47 deletions.
74 changes: 64 additions & 10 deletions src-tauri/Cargo.lock

(Generated file; diff not rendered.)

1 change: 1 addition & 0 deletions src-tauri/Cargo.toml
@@ -19,6 +19,7 @@ serde_json = "1.0"
home = "0.5.5"
rand = "0.8.5"
lazy_static = "1.4.0"
flume = "0.10.14"
anyhow = "1.0.71"
reqwest = { version = "0.11.17", features = ["stream"] }
futures-util = "0.3.28"
56 changes: 56 additions & 0 deletions src-tauri/src/cancellation.rs
@@ -0,0 +1,56 @@
use flume::{bounded, Receiver, SendError, Sender};
use std::convert::Infallible;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use tracing::info;

pub struct Canceller {
    cancelled: AtomicBool,
    rx: Mutex<Receiver<()>>,
    tx: Sender<()>,
}

impl Canceller {
    #[tracing::instrument(skip(self))]
    pub async fn cancel(&self) -> Result<(), SendError<()>> {
        info!("cancelling inference");
        self.cancelled.store(true, Ordering::Release);
        self.tx.send_async(()).await
    }

    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::Acquire)
    }

    pub fn reset(&self) {
        self.cancelled.store(false, Ordering::Release);
    }

    #[tracing::instrument(skip(self))]
    pub fn inference_feedback(&self) -> Result<llm::InferenceFeedback, Infallible> {
        // When a cancellation occurs, the sender will block until it is received at least
        // once. We want to check and see if that message has been sent, and if so we'll cancel.
        let cancelled = if let Ok(rx) = self.rx.try_lock() {
            rx.try_recv().is_ok()
        } else {
            false
        };
        if cancelled || self.is_cancelled() {
            info!("sending halt");
            Ok(llm::InferenceFeedback::Halt)
        } else {
            Ok(llm::InferenceFeedback::Continue)
        }
    }
}

impl Default for Canceller {
    fn default() -> Self {
        let (tx, rx) = bounded(0);
        Self {
            cancelled: Default::default(),
            rx: Mutex::new(rx),
            tx,
        }
    }
}
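
Design note (not part of this commit): Canceller pairs an AtomicBool flag with a zero-capacity flume channel, so cancel() both records the request and blocks until inference_feedback() has taken the message off the channel. The sketch below, assuming the flume and tokio crates, illustrates that rendezvous behaviour: send_async on a bounded(0) channel should only resolve once a receiver has picked up the message, which is what inference_feedback's try_recv relies on.

use flume::bounded;
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Capacity 0 makes this a rendezvous channel: a send completes only
    // once a receiver actually takes the message.
    let (tx, rx) = bounded::<()>(0);

    // Stand-in for the inference callback polling with try_recv().
    let poller = tokio::spawn(async move {
        loop {
            if rx.try_recv().is_ok() {
                println!("halt observed by the inference loop");
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    });

    // Like Canceller::cancel, this await should not return until the poller
    // above has received the message.
    tx.send_async(()).await.expect("receiver dropped");
    poller.await.unwrap();
}

A bounded(1) channel would let cancel() return immediately whether or not the inference loop ever noticed; bounding at 0 is presumably what makes the cancellation "blocking" in the commit message.
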
56 changes: 36 additions & 20 deletions src-tauri/src/main.rs
@@ -11,22 +11,22 @@ mod config;
mod events;
mod models;

mod cancellation;
mod context_file;
mod prompt;
#[cfg(target_os = "macos")]
mod titlebar;

#[cfg(target_os = "macos")]
use crate::titlebar::WindowExt;

use crate::cancellation::Canceller;
use crate::config::get_logs_dir;
use crate::events::Event;
use crate::models::{get_local_model, Architecture, Model, ModelManager};
use crate::prompt::Template;
#[cfg(target_os = "macos")]
use crate::titlebar::WindowExt;
use bytesize::ByteSize;
use llm::{InferenceResponse, LoadProgress};
use serde::Serialize;
use std::convert::Infallible;
use std::fs;
use std::fs::create_dir_all;
use std::path::PathBuf;
@@ -55,18 +55,25 @@ fn get_prompt_templates() -> Vec<Template> {
    prompt::AVAILABLE_TEMPLATES.clone()
}

#[tauri::command]
async fn cancel(canceller: tauri::State<'_, Canceller>) -> Result<(), String> {
    canceller.cancel().await.map_err(|err| err.to_string())
}

#[tauri::command]
async fn start(
    window: Window,
    state: tauri::State<'_, ManagerState>,
    canceller: tauri::State<'_, Canceller>,
    model_filename: String,
    architecture: String,
    tokenizer: String,
    context_size: usize,
    use_gpu: bool,
    prompt: Template,
    context_files: Vec<String>,
) -> Result<(), String> {
) -> Result<bool, String> {
    canceller.reset();
    let context = context_files
        .iter()
        .map(|path| context_file::read(PathBuf::from(path)))
@@ -175,28 +182,30 @@ async fn start(
                        progress,
                    }
                    .send(&window);
                    Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Continue)
                    canceller.inference_feedback()
                }
                _ => Ok(llm::InferenceFeedback::Continue),
                _ => canceller.inference_feedback(),
            }),
        )
        .map_err(|e| format!("Error feeding prompt: {}", e))?;
    Event::ModelLoading {
        message: "Model loaded".to_string(),
        progress: 1.0,
    }
    .send(&window);

    info!("finished warm-up prompt");
    if canceller.is_cancelled() {
        return Ok(false);
    }

    info!("finished warm-up prompt");
    *state.0.lock().unwrap() = Some(ModelManager {
        model,
        session,
        template: prompt,
    });

    Event::ModelLoading {
        message: "Model loaded".to_string(),
        progress: 1.0,
    }
    .send(&window);

    Ok(())
    Ok(true)
}

#[derive(Serialize)]
@@ -205,11 +214,12 @@ pub struct PromptResponse {
    pub message: String,
}

#[tracing::instrument(skip(window, state, message))]
#[tracing::instrument(skip(window, state, canceller, message))]
#[tauri::command]
async fn prompt(
    window: Window,
    state: tauri::State<'_, ManagerState>,
    canceller: tauri::State<'_, Canceller>,
    message: String,
) -> Result<PromptResponse, String> {
info!("received prompt");
@@ -221,9 +231,13 @@ async fn prompt(
    let manager: &mut ModelManager = (*binding).as_mut().ok_or("Model not started".to_string())?;
    let mut response = String::new();

    let stats = manager.infer(&message, |tokens| {
        response.push_str(&tokens);
        Event::PromptResponse { message: tokens }.send(&window);
    let stats = manager.infer(&message, |res| match res {
        InferenceResponse::InferredToken(tokens) => {
            response.push_str(&tokens);
            Event::PromptResponse { message: tokens }.send(&window);
            canceller.inference_feedback()
        }
        _ => canceller.inference_feedback(),
    })?;

info!("finished prompt response");
@@ -268,8 +282,10 @@ fn main() {
            get_architectures,
            get_prompt_templates,
            prompt,
            cancel,
        ])
        .manage(ManagerState(Mutex::new(None)));
        .manage(ManagerState(Mutex::new(None)))
        .manage(Canceller::default());

    // #[cfg(feature = "analytics")]
    // let panic_hook = tauri_plugin_aptabase::Builder::new(env!("APTABASE_KEY"))
(Remaining changed files not rendered in this view.)
