From 111c38237a737da4d69493f4013510d97d1fb5af Mon Sep 17 00:00:00 2001 From: Kodai Aoyama Date: Thu, 4 Apr 2024 20:14:08 +0900 Subject: [PATCH] fix --- src-tauri/src/main.rs | 81 ++++++++--------- src-tauri/src/module/device.rs | 10 ++- src-tauri/src/module/permissions.rs | 11 +-- src-tauri/src/module/recognizer.rs | 2 +- src-tauri/src/module/record.rs | 30 ++----- src-tauri/src/module/record_desktop.rs | 90 +++++++++---------- src-tauri/src/module/transcription.rs | 36 ++++++-- src/components/molecules/SpeakerLanguage.tsx | 4 +- .../molecules/TranscriptionAccuracy.tsx | 4 +- 9 files changed, 138 insertions(+), 130 deletions(-) diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index be03186..73233bb 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -3,25 +3,37 @@ windows_subsystem = "windows" )] -use crossbeam_channel::Sender; -use module::model_type_vosk::ModelTypeVosk; -use module::model_type_whisper::ModelTypeWhisper; -use module::transcription::TraceCompletion; -use tauri::http::{HttpRange, ResponseBuilder}; -use tauri::{Manager, State, Window}; +use tauri::{ + http::{HttpRange, ResponseBuilder}, + Manager, State, Window, +}; use tauri_plugin_sql::{Migration, MigrationKind}; -use crossbeam_channel::unbounded; -use std::cmp::min; -use std::io::{Read, Seek, SeekFrom}; -use std::path::PathBuf; -use std::str::FromStr; -use std::sync::{Arc, Mutex}; +use std::{ + cmp::min, + io::{Read, Seek, SeekFrom}, + path::PathBuf, + str::FromStr, + sync::{Arc, Mutex}, +}; +use crossbeam_channel::{unbounded, Sender}; use urlencoding::decode; mod module; -use module::device::Device; +use module::{ + chat_online::ChatOnline, + deleter::NoteDeleter, + device::{self, Device}, + downloader::{vosk::VoskModelDownloader, whisper::WhisperModelDownloader}, + model_type_vosk::ModelTypeVosk, + model_type_whisper::ModelTypeWhisper, + permissions, + record::Record, + record_desktop::RecordDesktop, + transcription::{TraceCompletion, Transcription}, + transcription_online::TranscriptionOnline, +}; struct RecordState(Arc>>>); @@ -30,7 +42,7 @@ const BUNDLE_IDENTIFIER: &str = "blog.aota.Lycoris"; #[tauri::command] fn delete_note_command(window: Window, note_id: u64) { std::thread::spawn(move || { - let deleter = module::deleter::NoteDeleter::new(window.app_handle().clone()); + let deleter = NoteDeleter::new(window.app_handle().clone()); deleter.delete(note_id) }); } @@ -38,8 +50,7 @@ fn delete_note_command(window: Window, note_id: u64) { #[tauri::command] fn download_whisper_model_command(window: Window, model: String) { std::thread::spawn(move || { - let dl = - module::downloader::whisper::WhisperModelDownloader::new(window.app_handle().clone()); + let dl = WhisperModelDownloader::new(window.app_handle().clone()); dl.download(ModelTypeWhisper::from_str(&model).unwrap()) }); } @@ -47,29 +58,29 @@ fn download_whisper_model_command(window: Window, model: String) { #[tauri::command] fn download_vosk_model_command(window: Window, model: String) { std::thread::spawn(move || { - let dl = module::downloader::vosk::VoskModelDownloader::new(window.app_handle().clone()); + let dl = VoskModelDownloader::new(window.app_handle().clone()); dl.download(ModelTypeVosk::from_str(&model).unwrap()) }); } #[tauri::command] fn list_devices_command() -> Vec { - module::device::list_devices() + device::list_devices() } #[tauri::command] fn has_accessibility_permission_command() -> bool { - module::permissions::has_accessibility_permission() + permissions::has_accessibility_permission() } #[tauri::command] fn has_screen_capture_permission_command(window: Window) -> bool { - module::permissions::has_screen_capture_permission(window) + permissions::has_screen_capture_permission(window) } #[tauri::command] fn has_microphone_permission_command(window: Window) -> bool { - module::permissions::has_microphone_permission(window) + permissions::has_microphone_permission(window) } #[tauri::command] @@ -87,39 +98,31 @@ fn start_command( *lock = Some(stop_record_tx); std::thread::spawn(move || { if device_type == "microphone" { - let record = module::record::Record::new(window.app_handle().clone()); + let record = Record::new(window.app_handle().clone()); record.start( device_label, speaker_language, transcription_accuracy, note_id, stop_record_rx, - Arc::new(Mutex::new(false)), ); } else if device_type == "desktop" { - let record_desktop = - module::record_desktop::RecordDesktop::new(window.app_handle().clone()); + let record_desktop = RecordDesktop::new(window.app_handle().clone()); record_desktop.start( speaker_language, transcription_accuracy, note_id, stop_record_rx, None, - Arc::new(Mutex::new(false)), ); } else { - let record = module::record::Record::new(window.app_handle().clone()); - let record_desktop = - module::record_desktop::RecordDesktop::new(window.app_handle().clone()); + let record = Record::new(window.app_handle().clone()); + let record_desktop = RecordDesktop::new(window.app_handle().clone()); let (stop_record_clone_tx, stop_record_clone_rx) = unbounded(); let speaker_language_clone = speaker_language.clone(); let transcription_accuracy_clone = transcription_accuracy.clone(); - let should_stop_other_transcription = Arc::new(Mutex::new(false)); - let should_stop_other_transcription_clone = - Arc::clone(&should_stop_other_transcription); - std::thread::spawn(move || { record_desktop.start( speaker_language_clone, @@ -127,7 +130,6 @@ fn start_command( note_id, stop_record_rx, Some(stop_record_clone_tx), - should_stop_other_transcription_clone, ); }); record.start( @@ -136,7 +138,6 @@ fn start_command( transcription_accuracy, note_id, stop_record_clone_rx.clone(), - should_stop_other_transcription, ); } }); @@ -164,7 +165,7 @@ fn start_trace_command( std::thread::spawn(move || { if transcription_accuracy.starts_with("online-transcript") { - let mut transcription_online = module::transcription_online::TranscriptionOnline::new( + let mut transcription_online = TranscriptionOnline::new( window.app_handle(), transcription_accuracy, speaker_language, @@ -172,14 +173,10 @@ fn start_trace_command( ); transcription_online.start(stop_convert_rx, true); } else if transcription_accuracy.starts_with("online-chat") { - let mut chat_online = module::chat_online::ChatOnline::new( - window.app_handle(), - speaker_language, - note_id, - ); + let mut chat_online = ChatOnline::new(window.app_handle(), speaker_language, note_id); chat_online.start(stop_convert_rx, true); } else { - let mut transcription = module::transcription::Transcription::new( + let mut transcription = Transcription::new( window.app_handle(), transcription_accuracy, speaker_language, diff --git a/src-tauri/src/module/device.rs b/src-tauri/src/module/device.rs index bbfe921..3f79f3a 100644 --- a/src-tauri/src/module/device.rs +++ b/src-tauri/src/module/device.rs @@ -12,9 +12,13 @@ pub fn list_devices() -> Vec { .input_devices() .unwrap() .filter_map(|device| { - Some(Device { - label: device.name().unwrap(), - }) + if device.name().is_ok() && device.name().unwrap().contains("ZoomAudioDevice") { + None + } else { + Some(Device { + label: device.name().unwrap(), + }) + } }) .collect(); } diff --git a/src-tauri/src/module/permissions.rs b/src-tauri/src/module/permissions.rs index ddfdb33..5873e2e 100644 --- a/src-tauri/src/module/permissions.rs +++ b/src-tauri/src/module/permissions.rs @@ -2,10 +2,11 @@ extern crate objc; extern crate objc_foundation; extern crate objc_id; -use objc::msg_send; -use objc::runtime::{Class, Object}; -use objc::sel; -use objc::sel_impl; +use objc::{ + msg_send, + runtime::{Class, Object}, + sel, sel_impl, +}; use objc_id::Id; use core_graphics::access::ScreenCaptureAccess; @@ -42,7 +43,7 @@ pub fn has_microphone_permission(window: Window) -> bool { pub fn has_screen_capture_permission(window: Window) -> bool { let access = ScreenCaptureAccess::default(); - let trusted = access.request(); + let trusted = access.preflight(); if !trusted { let func = |ok: bool| { if ok { diff --git a/src-tauri/src/module/recognizer.rs b/src-tauri/src/module/recognizer.rs index 640d796..23df6f1 100644 --- a/src-tauri/src/module/recognizer.rs +++ b/src-tauri/src/module/recognizer.rs @@ -29,13 +29,13 @@ impl MyRecognizer { pub fn recognize>( app_handle: AppHandle, - last_partial: &mut String, recognizer: &mut Recognizer, data: &[T], channels: ChannelCount, notify_decoding_state_is_finalized_tx: SyncSender, is_desktop: bool, ) { + let mut last_partial = String::new(); let data: Vec = data.iter().map(|v| v.to_sample()).collect(); let data = if channels != 1 { Self::stereo_to_mono(&data) diff --git a/src-tauri/src/module/record.rs b/src-tauri/src/module/record.rs index 4dfaa41..b0f434d 100644 --- a/src-tauri/src/module/record.rs +++ b/src-tauri/src/module/record.rs @@ -23,8 +23,8 @@ use crossbeam_channel::{unbounded, Receiver}; use tauri::{api::path::data_dir, AppHandle, Manager}; use super::{ - chat_online::ChatOnline, recognizer::MyRecognizer, sqlite::Sqlite, - transcription::Transcription, transcription_online::TranscriptionOnline, writer::Writer, + chat_online::ChatOnline, recognizer::MyRecognizer, sqlite::Sqlite, transcription, + transcription_online::TranscriptionOnline, writer::Writer, }; pub struct Record { @@ -43,16 +43,7 @@ impl Record { transcription_accuracy: String, note_id: u64, stop_record_rx: Receiver<()>, - should_stop_other_transcription: Arc>, ) { - let should_stop_other_transcription_on_record = - *should_stop_other_transcription.lock().unwrap(); - if !should_stop_other_transcription_on_record { - let mut lock = should_stop_other_transcription.lock().unwrap(); - *lock = true; - drop(lock); - } - let host = cpal::default_host(); let device = host .input_devices() @@ -75,7 +66,6 @@ impl Record { ); let recognizer = Arc::new(Mutex::new(recognizer)); let recognizer_clone = recognizer.clone(); - let mut last_partial = String::new(); let spec = Writer::wav_spec_from_config(&config); let data_dir = data_dir().unwrap_or(PathBuf::from("./")); @@ -97,7 +87,6 @@ impl Record { move |data: &[f32], _| { MyRecognizer::recognize( app_handle.clone(), - &mut last_partial, &mut recognizer_clone.lock().unwrap(), data, channels, @@ -113,7 +102,6 @@ impl Record { move |data: &[u16], _| { MyRecognizer::recognize( app_handle.clone(), - &mut last_partial, &mut recognizer_clone.lock().unwrap(), data, channels, @@ -129,7 +117,6 @@ impl Record { move |data: &[i16], _| { MyRecognizer::recognize( app_handle.clone(), - &mut last_partial, &mut recognizer_clone.lock().unwrap(), data, channels, @@ -184,10 +171,7 @@ impl Record { .lock() .unwrap() .replace(Writer::build(&audio_path.to_str().expect("error"), spec)); - if !is_no_transcription - && !*is_converting.lock().unwrap() - && !should_stop_other_transcription_on_record - { + if !is_no_transcription && !*is_converting.lock().unwrap() { let is_converting_clone = Arc::clone(&is_converting); let app_handle_clone = app_handle.clone(); let stop_convert_rx_clone = stop_convert_rx.clone(); @@ -214,13 +198,16 @@ impl Record { ); chat_online.start(stop_convert_rx_clone, false); } else { - let mut transcription = Transcription::new( + transcription::initialize_transcription( app_handle_clone, transcription_accuracy_clone, speaker_language_clone, note_id, ); - transcription.start(stop_convert_rx_clone, false); + let mut lock = transcription::SINGLETON_INSTANCE.lock().unwrap(); + if let Some(singleton) = lock.as_mut() { + singleton.start(stop_convert_rx_clone, false); + } } let mut lock = is_converting_clone.lock().unwrap(); @@ -245,6 +232,7 @@ impl Record { stop_writer_tx.send(()).unwrap(); if !is_no_transcription { stop_convert_tx.send(()).unwrap(); + transcription::drop_transcription(); } else { drop(stop_convert_tx) } diff --git a/src-tauri/src/module/record_desktop.rs b/src-tauri/src/module/record_desktop.rs index 9e93273..094dbb3 100644 --- a/src-tauri/src/module/record_desktop.rs +++ b/src-tauri/src/module/record_desktop.rs @@ -8,11 +8,14 @@ use crate::BUNDLE_IDENTIFIER; use std::{ - fs::remove_file, + fs::{remove_file, File}, + io::BufWriter, + option::Option, path::PathBuf, + string::String, sync::{ mpsc::{sync_channel, SyncSender}, - Arc, Mutex, + Arc, Mutex, Weak, }, thread, }; @@ -22,18 +25,21 @@ use crossbeam_channel::{unbounded, Receiver, Sender}; use hound::{WavSpec, WavWriter}; use tauri::{api::path::data_dir, AppHandle, Manager}; -use screencapturekit::cm_sample_buffer::CMSampleBuffer; -use screencapturekit::sc_content_filter::{InitParams, SCContentFilter}; -use screencapturekit::sc_error_handler::StreamErrorHandler; -use screencapturekit::sc_output_handler::{SCStreamOutputType, StreamOutput}; -use screencapturekit::sc_shareable_content::SCShareableContent; -use screencapturekit::sc_stream::SCStream; -use screencapturekit::sc_stream_configuration::SCStreamConfiguration; +use screencapturekit::{ + cm_sample_buffer::CMSampleBuffer, + sc_content_filter::{InitParams, SCContentFilter}, + sc_error_handler::StreamErrorHandler, + sc_output_handler::{SCStreamOutputType, StreamOutput}, + sc_shareable_content::SCShareableContent, + sc_stream::SCStream, + sc_stream_configuration::SCStreamConfiguration, +}; + use vosk::Recognizer; use super::{ - chat_online::ChatOnline, recognizer::MyRecognizer, sqlite::Sqlite, - transcription::Transcription, transcription_online::TranscriptionOnline, writer::Writer, + chat_online::ChatOnline, recognizer::MyRecognizer, sqlite::Sqlite, transcription, + transcription_online::TranscriptionOnline, writer::Writer, }; pub struct RecordDesktop { @@ -48,18 +54,10 @@ impl StreamErrorHandler for ErrorHandler { } struct StoreAudioHandler { - last_partial: Arc>, app_handle: AppHandle, - recognizer_clone: Arc>, + recognizer: Weak>, notify_decoding_state_is_finalized_tx: SyncSender, - writer_clone: Arc< - std::sync::Mutex< - std::option::Option<( - WavWriter>, - std::string::String, - )>, - >, - >, + writer_clone: Arc>, String)>>>, channels: u16, } @@ -77,15 +75,16 @@ impl StreamOutput for StoreAudioHandler { .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap())) .collect(); - MyRecognizer::recognize( - self.app_handle.clone(), - &mut self.last_partial.lock().unwrap(), - &mut self.recognizer_clone.lock().unwrap(), - &samples, - self.channels, - self.notify_decoding_state_is_finalized_tx.clone(), - true, - ); + if let Some(recognizer) = self.recognizer.upgrade() { + MyRecognizer::recognize( + self.app_handle.clone(), + &mut recognizer.lock().unwrap(), + &samples, + self.channels, + self.notify_decoding_state_is_finalized_tx.clone(), + true, + ); + } Writer::write_input_data::(&samples, &self.writer_clone); } @@ -103,16 +102,7 @@ impl RecordDesktop { note_id: u64, stop_record_rx: Receiver<()>, stop_record_clone_tx: Option>, - should_stop_other_transcription: Arc>, ) { - let should_stop_other_transcription_on_record_desktop = - *should_stop_other_transcription.lock().unwrap(); - if !should_stop_other_transcription_on_record_desktop { - let mut lock = should_stop_other_transcription.lock().unwrap(); - *lock = true; - drop(lock); - } - let mut current = SCShareableContent::current(); let display = current.displays.pop().unwrap(); @@ -134,9 +124,9 @@ impl RecordDesktop { speaker_language.clone(), config.sample_rate as f32, ); - let recognizer = Arc::new(Mutex::new(recognizer)); - let recognizer_clone = recognizer.clone(); - let last_partial = Arc::new(Mutex::new(String::new())); + let recognizer_arc = Arc::new(Mutex::new(recognizer)); + let recognizer_weak = Arc::downgrade(&recognizer_arc); + let spec = WavSpec { channels, sample_rate, @@ -160,9 +150,8 @@ impl RecordDesktop { let mut stream = SCStream::new(filter, config, ErrorHandler); stream.add_output( StoreAudioHandler { - last_partial, app_handle, - recognizer_clone, + recognizer: recognizer_weak, notify_decoding_state_is_finalized_tx, writer_clone, channels, @@ -215,10 +204,7 @@ impl RecordDesktop { .lock() .unwrap() .replace(Writer::build(&audio_path.to_str().expect("error"), spec)); - if !is_no_transcription - && !*is_converting.lock().unwrap() - && !should_stop_other_transcription_on_record_desktop - { + if !is_no_transcription && !*is_converting.lock().unwrap() { let is_converting_clone = Arc::clone(&is_converting); let app_handle_clone = app_handle.clone(); let stop_convert_rx_clone = stop_convert_rx.clone(); @@ -245,13 +231,16 @@ impl RecordDesktop { ); chat_online.start(stop_convert_rx_clone, false); } else { - let mut transcription = Transcription::new( + transcription::initialize_transcription( app_handle_clone, transcription_accuracy_clone, speaker_language_clone, note_id, ); - transcription.start(stop_convert_rx_clone, false); + let mut lock = transcription::SINGLETON_INSTANCE.lock().unwrap(); + if let Some(singleton) = lock.as_mut() { + singleton.start(stop_convert_rx_clone, false); + } } let mut lock = is_converting_clone.lock().unwrap(); @@ -280,6 +269,7 @@ impl RecordDesktop { stop_writer_tx.send(()).unwrap(); if !is_no_transcription { stop_convert_tx.send(()).unwrap(); + transcription::drop_transcription(); } else { drop(stop_convert_tx) } diff --git a/src-tauri/src/module/transcription.rs b/src-tauri/src/module/transcription.rs index 6d6c06a..95c7e26 100644 --- a/src-tauri/src/module/transcription.rs +++ b/src-tauri/src/module/transcription.rs @@ -1,9 +1,9 @@ use super::{sqlite::Sqlite, transcriber::Transcriber}; use crossbeam_channel::Receiver; - use hound::SampleFormat; use samplerate_rs::{convert, ConverterType}; +use std::sync::Mutex; use tauri::{AppHandle, Manager}; use whisper_rs::WhisperContext; @@ -27,7 +27,7 @@ impl Transcription { note_id: u64, ) -> Self { let app_handle_clone = app_handle.clone(); - Self { + Transcription { app_handle, sqlite: Sqlite::new(), ctx: Transcriber::build(app_handle_clone, transcription_accuracy.clone()), @@ -130,10 +130,10 @@ impl Transcription { .expect("failed to get number of segments"); let mut converted: Vec = vec!["".to_string()]; for i in 0..num_segments { - let segment = state - .full_get_segment_text(i) - .expect("failed to get segment"); - converted.push(segment.to_string()); + let segment = state.full_get_segment_text(i); + if segment.is_ok() { + converted.push(segment.unwrap().to_string()); + }; } let updated = self @@ -155,3 +155,27 @@ impl Transcription { }); } } + +pub static SINGLETON_INSTANCE: Mutex> = Mutex::new(None); + +pub fn initialize_transcription( + app_handle: AppHandle, + transcription_accuracy: String, + speaker_language: String, + note_id: u64, +) { + let mut singleton = SINGLETON_INSTANCE.lock().unwrap(); + if singleton.is_none() { + *singleton = Some(Transcription::new( + app_handle, + transcription_accuracy, + speaker_language, + note_id, + )); + } +} + +pub fn drop_transcription() { + let mut singleton = SINGLETON_INSTANCE.lock().unwrap(); + *singleton = None; +} diff --git a/src/components/molecules/SpeakerLanguage.tsx b/src/components/molecules/SpeakerLanguage.tsx index aa33697..8074216 100644 --- a/src/components/molecules/SpeakerLanguage.tsx +++ b/src/components/molecules/SpeakerLanguage.tsx @@ -3,11 +3,13 @@ import { useRecoilState, useRecoilValue } from 'recoil'; import { speakerLanguageState } from "../../store/atoms/speakerLanguageState"; import { modelVoskDownloadedState } from "../../store/atoms/modelVoskDownloadedState"; import { recordState } from "../../store/atoms/recordState"; +import { tracingState } from "../../store/atoms/tracingState"; const SpeakerLanguage = (): JSX.Element => { const downloadedModels = useRecoilValue(modelVoskDownloadedState) const [speakerLanguage, setSpeakerLanguage] = useRecoilState(speakerLanguageState) const isRecording = useRecoilValue(recordState) + const isTracing = useRecoilValue(tracingState); const dropdownRef = useRef(null) const change = (e: ChangeEvent) => { @@ -93,7 +95,7 @@ const SpeakerLanguage = (): JSX.Element => { return (
- {(isRecording || downloadedModels.length === 0) ?