diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index 1603e15..a7b1083 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -1,17 +1,14 @@ use std::{ path::PathBuf, - pin::pin, sync::{Arc, LazyLock}, - time::{Duration, Instant}, }; use miniserve::{http::StatusCode, Content, Request, Response}; use serde::{Deserialize, Serialize}; -use tokio::{ - fs, join, - sync::{mpsc, oneshot}, - task::JoinSet, -}; +use stateful::StatefulThread; +use tokio::{fs, join, task::JoinSet}; + +mod stateful; async fn index(_req: Request) -> Response { let content = include_str!("../index.html").to_string(); @@ -42,55 +39,49 @@ async fn load_docs(paths: Vec) -> Vec { docs } -type Payload = (Arc>, oneshot::Sender>>); - -fn chatbot_thread() -> (mpsc::Sender, mpsc::Sender<()>) { - let (req_tx, mut req_rx) = mpsc::channel::(1024); - let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); - tokio::spawn(async move { - let mut chatbot = chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]); - while let Some((messages, responder)) = req_rx.recv().await { - let doc_paths = chatbot.retrieval_documents(&messages); - let docs = load_docs(doc_paths).await; - let mut chat_fut = pin!(chatbot.query_chat(&messages, &docs)); - let mut cancel_fut = pin!(cancel_rx.recv()); - let start = Instant::now(); - loop { - let log_fut = tokio::time::sleep(Duration::from_secs(1)); - tokio::select! { - response = &mut chat_fut => { - responder.send(Some(response)).unwrap(); - break; - } - _ = &mut cancel_fut => { - responder.send(None).unwrap(); - break; - } - _ = log_fut => { - println!("Waiting for {} seconds", start.elapsed().as_secs()); - } - } - } - } - }); - (req_tx, cancel_tx) +struct LogFunction { + logger: chatbot::Logger, } -static CHATBOT_THREAD: LazyLock<(mpsc::Sender, mpsc::Sender<()>)> = - LazyLock::new(chatbot_thread); +impl stateful::StatefulFunction for LogFunction { + type Input = Arc>; + type Output = (); -async fn query_chat(messages: &Arc>) -> Option> { - let (tx, rx) = oneshot::channel(); - CHATBOT_THREAD - .0 - .send((Arc::clone(messages), tx)) - .await - .unwrap(); - rx.await.unwrap() + async fn call(&mut self, messages: Self::Input) -> Self::Output { + self.logger.append(messages.last().unwrap()); + self.logger.save().await.unwrap(); + } +} + +static LOG_THREAD: LazyLock> = LazyLock::new(|| { + StatefulThread::new(LogFunction { + logger: chatbot::Logger::default(), + }) +}); + +struct ChatbotFunction { + chatbot: chatbot::Chatbot, } +impl stateful::StatefulFunction for ChatbotFunction { + type Input = Arc>; + type Output = Vec; + + async fn call(&mut self, messages: Self::Input) -> Self::Output { + let doc_paths = self.chatbot.retrieval_documents(&messages); + let docs = load_docs(doc_paths).await; + self.chatbot.query_chat(&messages, &docs).await + } +} + +static CHATBOT_THREAD: LazyLock> = LazyLock::new(|| { + StatefulThread::new(ChatbotFunction { + chatbot: chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]), + }) +}); + async fn cancel(_req: Request) -> Response { - CHATBOT_THREAD.1.send(()).await.unwrap(); + CHATBOT_THREAD.cancel().await; Ok(Content::Html("success".into())) } @@ -103,7 +94,11 @@ async fn chat(req: Request) -> Response { }; let messages = Arc::new(data.messages); - let (i, responses_opt) = join!(chatbot::gen_random_number(), query_chat(&messages)); + let (i, responses_opt, _) = join!( + chatbot::gen_random_number(), + CHATBOT_THREAD.call(Arc::clone(&messages)), + LOG_THREAD.call(Arc::clone(&messages)) + ); let response = match responses_opt { Some(mut responses) => { diff --git a/crates/server/src/stateful.rs b/crates/server/src/stateful.rs new file mode 100644 index 0000000..bd3a199 --- /dev/null +++ b/crates/server/src/stateful.rs @@ -0,0 +1,73 @@ +use std::{ + fmt::Debug, + future::Future, + pin::pin, + time::{Duration, Instant}, +}; + +use tokio::{ + sync::{mpsc, oneshot}, + task::JoinHandle, +}; + +pub trait StatefulFunction: Send + 'static { + type Input: Send; + type Output: Send + Debug; + fn call(&mut self, input: Self::Input) -> impl Future + Send; +} + +type Payload = ( + ::Input, + oneshot::Sender::Output>>, +); + +pub struct StatefulThread { + _handle: JoinHandle<()>, + input_tx: mpsc::Sender>, + cancel_tx: mpsc::Sender<()>, +} + +impl StatefulThread { + pub fn new(mut func: F) -> Self { + let (input_tx, mut input_rx) = mpsc::channel::>(1024); + let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); + let _handle = tokio::spawn(async move { + while let Some((input, responder)) = input_rx.recv().await { + let mut output_fut = pin!(func.call(input)); + let mut cancel_fut = pin!(cancel_rx.recv()); + let start = Instant::now(); + loop { + let log_fut = tokio::time::sleep(Duration::from_secs(1)); + tokio::select! { + response = &mut output_fut => { + responder.send(Some(response)).unwrap(); + break; + } + _ = &mut cancel_fut => { + responder.send(None).unwrap(); + break; + } + _ = log_fut => { + println!("Waiting for {} seconds", start.elapsed().as_secs()); + } + } + } + } + }); + StatefulThread { + _handle, + input_tx, + cancel_tx, + } + } + + pub async fn call(&self, input: F::Input) -> Option { + let (tx, rx) = oneshot::channel(); + self.input_tx.send((input, tx)).await.unwrap(); + rx.await.unwrap() + } + + pub async fn cancel(&self) { + self.cancel_tx.send(()).await.unwrap(); + } +}