Skip to content

Commit

Permalink
feat: Message metadata (#5)
Browse files Browse the repository at this point in the history
* chore: add metadata

* chore: v2 stream message

* chore: fix test
  • Loading branch information
appflowy authored Aug 12, 2024
1 parent 3462f22 commit 6f064ef
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 28 deletions.
66 changes: 59 additions & 7 deletions appflowy-local-ai/src/ai_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Weak;
use tokio_stream::wrappers::ReceiverStream;
use tracing::instrument;
use tracing::{instrument, trace};

pub struct AIPluginOperation {
plugin: Weak<Plugin>,
Expand Down Expand Up @@ -42,11 +42,11 @@ impl AIPluginOperation {
plugin.async_request::<T>("handle", &request).await
}

pub async fn create_chat(&self, chat_id: &str, rag_enabled: bool) -> Result<(), PluginError> {
pub async fn create_chat(&self, chat_id: &str) -> Result<(), PluginError> {
self
.send_request::<DefaultResponseParser>(
"create_chat",
json!({ "chat_id": chat_id, "rag_enabled": rag_enabled }),
json!({ "chat_id": chat_id, "top_k": 2}),
)
.await
}
Expand Down Expand Up @@ -76,16 +76,31 @@ impl AIPluginOperation {
&self,
chat_id: &str,
message: &str,
rag_enabled: bool,
metadata: serde_json::Value,
) -> Result<ReceiverStream<Result<Bytes, PluginError>>, PluginError> {
let plugin = self.get_plugin()?;
let params = json!({
"chat_id": chat_id,
"method": "stream_answer",
"params": { "content": message, "rag_enabled": rag_enabled }
"params": { "content": message, "metadata": metadata }
});
plugin.stream_request::<ChatStreamResponseParser>("handle", &params)
}
#[instrument(level = "debug", skip(self), err)]
pub async fn stream_message_v2(
&self,
chat_id: &str,
message: &str,
metadata: serde_json::Value,
) -> Result<ReceiverStream<Result<serde_json::Value, PluginError>>, PluginError> {
let plugin = self.get_plugin()?;
let params = json!({
"chat_id": chat_id,
"method": "stream_answer_v2",
"params": { "content": message, "metadata": metadata }
});
plugin.stream_request::<ChatStreamResponseV2Parser>("handle", &params)
}

pub async fn get_related_questions(&self, chat_id: &str) -> Result<Vec<String>, PluginError> {
self
Expand All @@ -96,8 +111,33 @@ impl AIPluginOperation {
.await
}

pub async fn index_file(&self, chat_id: &str, file_path: &str) -> Result<(), PluginError> {
let params = json!({ "file_path": file_path, "metadatas": [{"chat_id": chat_id}] });
#[instrument(level = "debug", skip_all, err)]
pub async fn index_file(
&self,
chat_id: &str,
file_path: Option<String>,
file_content: Option<String>,
metadata: Option<HashMap<String, serde_json::Value>>,
) -> Result<(), PluginError> {
if file_path.is_none() && file_content.is_none() {
return Err(PluginError::Internal(anyhow!(
"file_path or content must be provided"
)));
}

let mut metadata = metadata.unwrap_or_default();
metadata.insert("chat_id".to_string(), json!(chat_id));
let mut params = json!({ "metadata": [metadata] });

if let Some(file_path) = file_path {
params["file_path"] = json!(file_path);
}

if let Some(content) = file_content {
params["file_content"] = json!(content);
}

trace!("[AI Plugin] indexing file: {:?}", params);
self
.send_request::<DefaultResponseParser>(
"index_file",
Expand Down Expand Up @@ -184,6 +224,18 @@ impl ResponseParser for ChatStreamResponseParser {
}
}

pub struct ChatStreamResponseV2Parser;
impl ResponseParser for ChatStreamResponseV2Parser {
type ValueType = serde_json::Value;

fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
json
.as_str()
.and_then(|s| serde_json::from_str(s).ok())
.ok_or(RemoteError::ParseResponse(json))
}
}

pub struct ChatRelatedQuestionsResponseParser;
impl ResponseParser for ChatRelatedQuestionsResponseParser {
type ValueType = Vec<String>;
Expand Down
53 changes: 37 additions & 16 deletions appflowy-local-ai/src/chat_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use appflowy_plugin::manager::PluginManager;
use appflowy_plugin::util::{get_operating_system, OperatingSystem};
use bytes::Bytes;
use serde::{Deserialize, Serialize};

use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Debug;
use std::path::PathBuf;
Expand Down Expand Up @@ -72,7 +74,7 @@ impl AppFlowyLocalAI {

let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
operation.create_chat(chat_id, true).await?;
operation.create_chat(chat_id).await?;
Ok(())
}

Expand Down Expand Up @@ -115,12 +117,15 @@ impl AppFlowyLocalAI {
&self,
chat_id: &str,
message: &str,
) -> Result<ReceiverStream<anyhow::Result<Bytes, PluginError>>, PluginError> {
metadata: serde_json::Value,
) -> Result<ReceiverStream<anyhow::Result<Value, PluginError>>, PluginError> {
trace!("[AI Plugin] ask question: {}", message);
self.wait_until_plugin_ready().await?;
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
let stream = operation.stream_message(chat_id, message, true).await?;
let stream = operation
.stream_message_v2(chat_id, message, metadata)
.await?;
Ok(stream)
}

Expand All @@ -132,24 +137,40 @@ impl AppFlowyLocalAI {
Ok(values)
}

pub async fn index_file(&self, chat_id: &str, file_path: PathBuf) -> Result<(), PluginError> {
if !file_path.exists() {
return Err(PluginError::Io(io::Error::new(
io::ErrorKind::NotFound,
"file not found",
)));
}
pub async fn index_file(
&self,
chat_id: &str,
file_path: Option<PathBuf>,
file_content: Option<String>,
metadata: Option<HashMap<String, serde_json::Value>>,
) -> Result<(), PluginError> {
let mut file_path_str = None;
if let Some(file_path) = file_path {
if !file_path.exists() {
return Err(PluginError::Io(io::Error::new(
io::ErrorKind::NotFound,
"file not found",
)));
}

let file_path = file_path.to_str().ok_or(PluginError::Io(io::Error::new(
io::ErrorKind::NotFound,
"file path invalid",
)))?;
file_path_str = Some(
file_path
.to_str()
.ok_or(PluginError::Io(io::Error::new(
io::ErrorKind::NotFound,
"file path invalid",
)))?
.to_string(),
);
}

self.wait_until_plugin_ready().await?;
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
trace!("[AI Plugin] indexing file: {}", file_path);
operation.index_file(chat_id, file_path).await?;

operation
.index_file(chat_id, file_path_str, file_content, metadata)
.await?;
Ok(())
}

Expand Down
12 changes: 10 additions & 2 deletions appflowy-local-ai/tests/chat_test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::collections::HashMap;

use appflowy_local_ai::ai_ops::{CompleteTextType, LocalAITranslateItem, LocalAITranslateRowData};
use appflowy_plugin::manager::PluginManager;
use serde_json::Value;
use std::env::temp_dir;
use std::path::PathBuf;
use std::sync::Arc;
Expand Down Expand Up @@ -51,7 +52,10 @@ async fn ci_chat_stream_test() {
let mut resp = test.stream_chat_message(&chat_id, "what is banana?").await;
let mut list = vec![];
while let Some(s) = resp.next().await {
list.push(String::from_utf8(s.unwrap().to_vec()).unwrap());
if let Value::Object(mut map) = s.unwrap() {
let s = map.remove("1").unwrap().as_str().unwrap().to_string();
list.push(s);
}
}

let answer = list.join("");
Expand Down Expand Up @@ -111,7 +115,11 @@ async fn ci_chat_with_pdf() {
test.init_embedding_plugin().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let pdf = get_asset_path("AppFlowy_Values.pdf");
test.local_ai.index_file(&chat_id, pdf).await.unwrap();
test
.local_ai
.index_file(&chat_id, Some(pdf), None, None)
.await
.unwrap();

let resp = test
.local_ai
Expand Down
7 changes: 4 additions & 3 deletions appflowy-local-ai/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use appflowy_local_ai::chat_plugin::{AIPluginConfig, AppFlowyLocalAI};
use appflowy_local_ai::embedding_plugin::{EmbeddingPluginConfig, LocalEmbedding};
use appflowy_plugin::error::PluginError;
use appflowy_plugin::manager::PluginManager;
use bytes::Bytes;

use serde_json::{json, Value};
use simsimd::SpatialSimilarity;
use std::f64;
use std::path::{Path, PathBuf};
Expand Down Expand Up @@ -75,10 +76,10 @@ impl LocalAITest {
&self,
chat_id: &str,
message: &str,
) -> ReceiverStream<Result<Bytes, PluginError>> {
) -> ReceiverStream<Result<Value, PluginError>> {
self
.local_ai
.stream_question(chat_id, message)
.stream_question(chat_id, message, json!([]))
.await
.unwrap()
}
Expand Down

0 comments on commit 6f064ef

Please sign in to comment.