From 3a8f1fcf21984976c0c98786b9f44ed4f6eea8e4 Mon Sep 17 00:00:00 2001 From: Andrew Moon Date: Sun, 15 Dec 2024 14:56:33 +0900 Subject: [PATCH] Transcription / Summary / Sentiment analysis with OpenAI (#33) * wip: openai whisper integration * feat: working implementation of transcription * feat: launch subtask for transcription * feat: summarize * feat: summarize and extract sentiment --- Recordiary/Recordiary/Models/Diary.swift | 4 + backend/Cargo.lock | 110 ++++++++++++++++++- backend/Cargo.toml | 2 + backend/src/db/diary.rs | 46 +++++++- backend/src/handlers/deco.rs | 2 +- backend/src/handlers/diary.rs | 47 +++++++- backend/src/lib.rs | 12 ++ backend/src/openai.rs | 2 + backend/src/openai/client.rs | 133 +++++++++++++++++++++++ backend/src/openai/diary.rs | 41 +++++++ backend/src/storage/client.rs | 12 +- 11 files changed, 391 insertions(+), 20 deletions(-) create mode 100644 backend/src/openai.rs create mode 100644 backend/src/openai/client.rs create mode 100644 backend/src/openai/diary.rs diff --git a/Recordiary/Recordiary/Models/Diary.swift b/Recordiary/Recordiary/Models/Diary.swift index e508bd0..dba7109 100644 --- a/Recordiary/Recordiary/Models/Diary.swift +++ b/Recordiary/Recordiary/Models/Diary.swift @@ -14,6 +14,8 @@ struct DiaryModel: Decodable { let userId: String let audioLink: String let summary: String? + let transcription: String? + let emotion: String? let isPrivate: Bool enum CodingKeys: String, CodingKey { @@ -23,6 +25,8 @@ struct DiaryModel: Decodable { case userId = "user_id" case audioLink = "audio_link" case summary + case transcription + case emotion case isPrivate = "is_private" } } diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 8f36205..aa62e1a 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -321,6 +321,12 @@ dependencies = [ "typenum", ] +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + [[package]] name = "der" version = "0.7.9" @@ -518,6 +524,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -538,6 +555,7 @@ checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1057,6 +1075,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1216,6 +1244,21 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "openai-api-rs" +version = "5.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0061068e3cd1d5a538a0c61484fb1e0722f5113d107e7e1c652b62a6fcba148" +dependencies = [ + "bytes", + "futures-util", + "reqwest", + "serde", + "serde_json", + "tokio", + "tokio-tungstenite", +] + [[package]] name = "openssl" version = "0.10.68" @@ -1426,10 +1469,12 @@ dependencies = [ "axum", "hyper", "itertools", + "openai-api-rs", "reqwest", "serde", "serde_json", "sqlx", + "tempfile", "time", "tokio", "tower-http", @@ -1514,6 +1559,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", @@ -1526,6 +1572,7 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", + "tokio-socks", "tower-service", "url", "wasm-bindgen", @@ -2137,9 +2184,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" dependencies = [ "cfg-if", "fastrand", @@ -2284,6 +2331,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-socks" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f" +dependencies = [ + "either", + "futures-util", + "thiserror", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.16" @@ -2295,6 +2354,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "native-tls", + "tokio", + "tokio-native-tls", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.12" @@ -2421,12 +2494,37 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "native-tls", + "rand", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typenum" version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unicase" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df" + [[package]] name = "unicode-bidi" version = "0.3.17" @@ -2477,6 +2575,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf16_iter" version = "1.0.5" diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 874df65..5b6e164 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -29,3 +29,5 @@ tower-http = { version = "0.6.1", features = ["limit", "trace"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } itertools = "0.13.0" +tempfile = "3.14.0" +openai-api-rs = "5.2.3" diff --git a/backend/src/db/diary.rs b/backend/src/db/diary.rs index 3454305..462b8f6 100644 --- a/backend/src/db/diary.rs +++ b/backend/src/db/diary.rs @@ -11,6 +11,8 @@ pub struct Diary { user_id: Uuid, audio_link: Option, summary: Option, + transcription: Option, + emotion: Option, is_private: bool, } @@ -94,25 +96,61 @@ pub async fn update_diary( id: i64, audio_link: Option, summary: Option, + transcription: Option, + emotion: Option, is_private: Option, ) -> anyhow::Result<()> { - if audio_link.is_none() && summary.is_none() && is_private.is_none() { + if audio_link.is_none() + && summary.is_none() + && transcription.is_none() + && is_private.is_none() + && emotion.is_none() + { return Ok(()); } let mut qry_builder: sqlx::QueryBuilder<'_, Postgres> = sqlx::query_builder::QueryBuilder::new("UPDATE diary SET "); let mut separated = qry_builder.separated(", "); + let mut first = true; + if let Some(audio_link) = audio_link { + if !first { + separated.push_unseparated(", "); + } separated.push_unseparated("audio_link = "); - separated.push_bind(audio_link); + separated.push_bind_unseparated(audio_link); + first = false; } if let Some(summary) = summary { + if !first { + separated.push_unseparated(", "); + } separated.push_unseparated("summary = "); - separated.push_bind(summary); + separated.push_bind_unseparated(summary); + first = false; + } + if let Some(transcription) = transcription { + if !first { + separated.push_unseparated(", "); + } + separated.push_unseparated("transcription = "); + separated.push_bind_unseparated(transcription); + first = false; + } + if let Some(emotion) = emotion { + if !first { + separated.push_unseparated(", "); + } + separated.push_unseparated("emotion = "); + separated.push_bind_unseparated(emotion); + first = false; } if let Some(is_private) = is_private { + if !first { + separated.push_unseparated(", "); + } separated.push_unseparated("is_private = "); - separated.push_bind(is_private); + separated.push_bind_unseparated(is_private); } qry_builder.push(" WHERE id = ").push_bind(id); diff --git a/backend/src/handlers/deco.rs b/backend/src/handlers/deco.rs index 8212b98..dd25bbe 100644 --- a/backend/src/handlers/deco.rs +++ b/backend/src/handlers/deco.rs @@ -95,7 +95,7 @@ pub async fn create_deco( }; // upload model to storage let url = storage_client - .upload_model(model_bytes.to_vec(), params.name.clone()) + .upload_model(model_bytes.to_vec(), ¶ms.name) .await?; let deco = crate::db::deco::create_deco( diff --git a/backend/src/handlers/diary.rs b/backend/src/handlers/diary.rs index ab6a1f9..8324c2f 100644 --- a/backend/src/handlers/diary.rs +++ b/backend/src/handlers/diary.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use axum::{ debug_handler, extract::{Multipart, Query, State}, @@ -10,6 +12,7 @@ use uuid::Uuid; use crate::{ db::diary::{insert_diary, update_diary, Diary}, + openai::{client::OpenAIClient, diary::summarize_diary}, storage::client::SupabaseClient, utils::{get_diary_filename, parse_multipart::parse_multipart, sqlx::get_pg_tx}, AppState, @@ -56,11 +59,12 @@ pub struct CreateDiaryParams { #[debug_handler(state = AppState)] pub async fn create_diary( State(pool): State, + State(openai_client): State>, State(storage_client): State, Query(params): Query, multipart: Multipart, ) -> axum::response::Result { - let mut tx = get_pg_tx(pool).await?; + let mut tx = get_pg_tx(pool.clone()).await?; let result: anyhow::Result<_> = async { let (audio_bytes, _audio_metadata) = match parse_multipart(multipart).await { Ok(audio_data) => audio_data, @@ -77,14 +81,45 @@ pub async fn create_diary( ), ) .await?; + let audio_title = get_diary_filename(params.user_id, diary_id); let audio_link = storage_client - .upload_diary( - audio_bytes.to_vec(), - get_diary_filename(params.user_id, diary_id), - ) + .upload_diary(audio_bytes.to_vec(), &audio_title) .await?; + update_diary(&mut tx, diary_id, Some(audio_link), None, None, None, None).await?; + + // create a background subtask to transcribe the audio + tokio::spawn(async move { + let mut tx = get_pg_tx(pool.clone()).await.unwrap(); + match openai_client.transcribe(&audio_title, &audio_bytes).await { + Ok(audio_transcription) => { + if let Err(e) = update_diary( + &mut tx, + diary_id, + None, + None, + Some(audio_transcription.clone()), + None, + None, + ) + .await + { + tracing::error!("Failed to update diary with transcription: {}", e); + } else { + if let Err(e) = tx.commit().await { + tracing::error!("Failed to commit transaction: {}", e); + } + tokio::spawn(async move { + summarize_diary(pool, openai_client, diary_id, audio_transcription) + .await; + }); + } + } + Err(e) => { + tracing::error!("Failed to transcribe audio: {}", e); + } + } + }); - update_diary(&mut tx, diary_id, Some(audio_link), None, None).await?; Ok(diary_id) } .await; diff --git a/backend/src/lib.rs b/backend/src/lib.rs index cad7e9f..fcb9675 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -1,9 +1,13 @@ +use std::sync::Arc; + use axum::extract::FromRef; use db::conn::initialize_conn_pool; +use openai::client::OpenAIClient; use storage::client::SupabaseClient; pub mod db; pub mod handlers; +pub mod openai; pub mod storage; pub mod utils; @@ -11,6 +15,7 @@ pub mod utils; pub struct AppState { pool: sqlx::PgPool, storage_client: SupabaseClient, + openai_client: Arc, } impl AppState { @@ -23,6 +28,7 @@ impl AppState { std::env::var("SUPABASE_AUDIO_BUCKET").unwrap(), std::env::var("SUPABASE_MODEL_BUCKET").unwrap(), ), + openai_client: Arc::new(OpenAIClient::new()), } } } @@ -38,3 +44,9 @@ impl FromRef for SupabaseClient { state.storage_client.clone() } } + +impl FromRef for Arc { + fn from_ref(state: &AppState) -> Arc { + state.openai_client.clone() + } +} diff --git a/backend/src/openai.rs b/backend/src/openai.rs new file mode 100644 index 0000000..2ebcdbf --- /dev/null +++ b/backend/src/openai.rs @@ -0,0 +1,2 @@ +pub mod client; +pub mod diary; diff --git a/backend/src/openai/client.rs b/backend/src/openai/client.rs new file mode 100644 index 0000000..53e12d2 --- /dev/null +++ b/backend/src/openai/client.rs @@ -0,0 +1,133 @@ +use std::{fs::File, io::Write}; + +use openai_api_rs::v1::{ + audio::AudioTranscriptionRequest, + chat_completion::{ChatCompletionMessage, ChatCompletionRequest, Content, MessageRole}, +}; + +pub enum Emotion { + Anger, + Sadness, + Happiness, + Neutral, +} + +pub struct OpenAIClient { + openai: openai_api_rs::v1::api::OpenAIClient, +} + +impl OpenAIClient { + pub fn new() -> Self { + Self { + openai: openai_api_rs::v1::api::OpenAIClient::builder() + .with_api_key(std::env::var("OPENAI_API_KEY").unwrap()) + .build() + .expect("Failed to create OpenAI client"), + } + } + + pub async fn transcribe( + &self, + audio_title: &str, + audio_content: &[u8], + ) -> anyhow::Result { + let tmp_dir = tempfile::tempdir()?; // the directory will be dropped with the lifetime + let tmp_path = tmp_dir.path().join(audio_title); + let mut audio_file = File::create(tmp_path.clone())?; + + audio_file.write_all(audio_content)?; + audio_file.flush()?; + let request = AudioTranscriptionRequest { + file: tmp_path.to_string_lossy().to_string(), + model: "whisper-1".to_string(), + prompt: None, + response_format: Some("json".to_string()), + temperature: None, + language: Some("ko".to_string()), + }; + let resp = self.openai.audio_transcription(request).await?; + Ok(resp.text) + } + + pub async fn summarize(&self, content: &str) -> anyhow::Result { + let request = ChatCompletionRequest { + model: "gpt-3.5-turbo".to_string(), + messages: vec![ChatCompletionMessage { + role: MessageRole::system, + content: Content::Text( + "Summarize the following text in korean: ".to_string() + content, + ), + name: None, + tool_call_id: None, + tool_calls: None, + }], + temperature: Some(0.0), + top_p: None, + n: None, + response_format: None, + stream: None, + stop: None, + max_tokens: Some(50), + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + seed: None, + tools: None, + parallel_tool_calls: None, + tool_choice: None, + }; + let resp = self.openai.chat_completion(request).await?; + Ok(resp.choices[0] + .message + .content + .clone() + .unwrap_or("".to_string()) + .to_string()) + } + + pub async fn sentiment(&self, content: &str) -> anyhow::Result { + let base_prompt = "Analyze the sentiment of the following korean text. + Only respond with one of the following choices: anger, sadness, happiness, neutral. + Target text: "; + let request = ChatCompletionRequest { + model: "gpt-3.5-turbo".to_string(), + messages: vec![ChatCompletionMessage { + role: MessageRole::system, + content: Content::Text(base_prompt.to_string() + content), + name: None, + tool_call_id: None, + tool_calls: None, + }], + temperature: Some(0.0), + top_p: None, + n: None, + response_format: None, + stream: None, + stop: None, + max_tokens: Some(50), + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + seed: None, + tools: None, + parallel_tool_calls: None, + tool_choice: None, + }; + let resp = self.openai.chat_completion(request).await?; + let raw_emotion = resp.choices[0] + .message + .content + .clone() + .unwrap_or("neutral".to_string()) + .to_string(); + Ok(raw_emotion) + } +} + +impl Default for OpenAIClient { + fn default() -> Self { + Self::new() + } +} diff --git a/backend/src/openai/diary.rs b/backend/src/openai/diary.rs new file mode 100644 index 0000000..0250c51 --- /dev/null +++ b/backend/src/openai/diary.rs @@ -0,0 +1,41 @@ +use std::sync::Arc; + +use sqlx::PgPool; + +use crate::{db::diary::update_diary, utils::sqlx::get_pg_tx}; + +use super::client::OpenAIClient; + +pub async fn summarize_diary( + pool: PgPool, + client: Arc, + diary_id: i64, + transcription: String, +) { + let mut tx = get_pg_tx(pool).await.unwrap(); + + let summary = client.summarize(&transcription).await.unwrap(); + let emotion = client.sentiment(&transcription).await.unwrap(); + let res = update_diary( + &mut tx, + diary_id, + None, + Some(summary), + None, + Some(emotion), + None, + ) + .await; + + let tx_res = match res { + Ok(_) => tx.commit().await, + Err(e) => { + tracing::error!("Summarize error: {:?}", e); + tx.rollback().await + } + }; + + if let Err(e) = tx_res { + tracing::error!("Summarize tx error: {:?}", e); + } +} diff --git a/backend/src/storage/client.rs b/backend/src/storage/client.rs index 86428b5..ece7ba3 100644 --- a/backend/src/storage/client.rs +++ b/backend/src/storage/client.rs @@ -46,9 +46,9 @@ impl SupabaseClient { pub async fn upload_diary( &self, audio: Vec, - filename: String, + filename: &str, ) -> Result { - self.upload(self.audio_bucket.clone(), filename.clone(), audio) + self.upload(self.audio_bucket.clone(), filename, audio) .await?; let presigned_suffix = self .get_presigned_download_url(self.audio_bucket.clone(), filename) @@ -63,9 +63,9 @@ impl SupabaseClient { pub async fn upload_model( &self, model: Vec, - filename: String, + filename: &str, ) -> Result { - self.upload(self.model_bucket.clone(), filename.clone(), model) + self.upload(self.model_bucket.clone(), filename, model) .await?; let presigned_suffix = self .get_presigned_download_url(self.model_bucket.clone(), filename) @@ -80,7 +80,7 @@ impl SupabaseClient { pub async fn upload( &self, bucket: String, - filename: String, + filename: &str, file: Vec, ) -> Result<(), ReqwestError> { let url: String = format!( @@ -108,7 +108,7 @@ impl SupabaseClient { pub async fn get_presigned_download_url( &self, bucket: String, - filename: String, + filename: &str, ) -> Result { let url = format!( "{}/storage/v1/object/sign/{}/{}",