diff --git a/Cargo.lock b/Cargo.lock index 0800c46..0c9a730 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -50,6 +50,21 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "ansi-to-tui" version = "7.0.0" @@ -177,6 +192,12 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.22.1" @@ -347,6 +368,26 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + +[[package]] +name = "chrono" +version = "0.4.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets 0.52.6", +] + [[package]] name = "clap" version = "4.5.17" @@ -524,6 +565,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "ctrlc" +version = "3.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3" +dependencies = [ + "nix", + "windows-sys 0.59.0", +] + [[package]] name = "darling" version = "0.20.10" @@ -643,6 +694,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fdeflate" version = "0.3.4" @@ -1014,6 +1071,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core 0.52.0", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -1121,6 +1201,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.158" @@ -1210,6 +1296,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1248,6 +1344,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nom" version = "7.1.3" @@ -1521,7 +1629,7 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42cf17e9a1800f5f396bc67d193dc9411b59012a5876445ef450d449881e1016" dependencies = [ - "base64", + "base64 0.22.1", "indexmap", "quick-xml", "serde", @@ -1738,7 +1846,7 @@ version = "0.12.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "futures-core", "futures-util", @@ -1752,6 +1860,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "once_cell", "percent-encoding", "pin-project-lite", @@ -1765,10 +1874,12 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "windows-registry", @@ -1843,7 +1954,7 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" dependencies = [ - "base64", + "base64 0.22.1", "rustls-pki-types", ] @@ -2175,6 +2286,19 @@ dependencies = [ "libc", ] +[[package]] +name = "tempfile" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" +dependencies = [ + "cfg-if", + "fastrand", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "tenere" version = "0.11.2" @@ -2182,11 +2306,17 @@ dependencies = [ "ansi-to-tui", "arboard", "async-trait", + "base64 0.13.1", "bat", + "bytes", + "chrono", "clap", "crossterm", + "ctrlc", "dirs", "futures", + "lazy_static", + "libc", "ratatui", "regex", "reqwest", @@ -2194,8 +2324,10 @@ dependencies = [ "serde_json", "strum", "strum_macros", + "tempfile", "tokio", "toml", + "toml_edit 0.21.1", "tui-textarea", "unicode-width 0.2.0", ] @@ -2363,6 +2495,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.8.19" @@ -2373,7 +2518,7 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit", + "toml_edit 0.22.20", ] [[package]] @@ -2385,6 +2530,17 @@ dependencies = [ "serde", ] +[[package]] +name = "toml_edit" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow 0.5.40", +] + [[package]] name = "toml_edit" version = "0.22.20" @@ -2395,7 +2551,7 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "winnow", + "winnow 0.6.18", ] [[package]] @@ -2461,6 +2617,12 @@ dependencies = [ "unicode-width 0.2.0", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-bidi" version = "0.3.15" @@ -2644,6 +2806,19 @@ version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +[[package]] +name = "wasm-streams" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.70" @@ -2715,7 +2890,16 @@ version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1de69df01bdf1ead2f4ac895dc77c9351aefff65b2f3db429a343f9cbf05e132" dependencies = [ - "windows-core", + "windows-core 0.56.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ "windows-targets 0.52.6", ] @@ -2940,6 +3124,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" +dependencies = [ + "memchr", +] + [[package]] name = "winnow" version = "0.6.18" diff --git a/Cargo.toml b/Cargo.toml index b75498b..577a86e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,8 @@ futures = "0.3" reqwest = { version = "0.12", default-features = false, features = [ "json", "rustls-tls", + "stream", # For byte streaming + "multipart", # Add this feature for form uploads ] } ratatui = { version = "0.29", features = ["all-widgets"] } regex = "1" @@ -32,6 +34,14 @@ tokio = { version = "1", features = ["full"] } toml = { version = "0.8" } tui-textarea = "0.7" unicode-width = "0.2" +base64 = "0.13" +tempfile = "3" +bytes = "1.5.0" +chrono = "0.4" # For timestamping debug logs +toml_edit = "0.21.0" +lazy_static = "1.4.0" +ctrlc = "3.4.0" +libc = "0.2.142" # For process killing on Unix [profile.release] lto = "fat" diff --git a/README.md b/README.md index fd21747..475b83b 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,7 @@ Include your API key in the configuration file: openai_api_key = "Your API key here" model = "gpt-3.5-turbo" url = "https://api.openai.com/v1/chat/completions" +system_prompt = "You are a helpful assistant." ``` The default model is set to `gpt-3.5-turbo`. Check out the [OpenAI documentation](https://platform.openai.com/docs/models/gpt-3-5) for more info. diff --git a/example_config.toml b/example_config.toml new file mode 100644 index 0000000..71d46cf --- /dev/null +++ b/example_config.toml @@ -0,0 +1,16 @@ +llm = "chatgpt" + +[chatgpt] +openai_api_key = "your-api-key-here" +model = "gpt-4-turbo" +system_prompt = "You are an AI assistant that specializes in programming and software development." + +[key_bindings] +show_help = '?' +show_history = 'h' +new_chat = 'n' +stop_stream = 't' +load_voice = 'v' + +[tts] +url = "http://0.0.0.0:8000/v1/audio/speech" diff --git a/scripts/test_tts.sh b/scripts/test_tts.sh new file mode 100644 index 0000000..2301173 --- /dev/null +++ b/scripts/test_tts.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# A test script to verify your audio playback system works + +echo "Testing audio playback system..." + +# Check for media players +echo "Checking for media players..." +which mpv >/dev/null && echo "✓ mpv found" || echo "✗ mpv not found" +which ffplay >/dev/null && echo "✓ ffplay found" || echo "✗ ffplay not found" +which aplay >/dev/null && echo "✓ aplay found" || echo "✗ aplay not found" + +# Generate a test tone using sox (if available) +if which sox >/dev/null; then + echo "Generating test tone with sox..." + sox -n /tmp/test_tone.mp3 synth 2 sine 440 + + # Try to play with each player + echo "Playing with mpv..." + mpv /tmp/test_tone.mp3 --no-terminal >/dev/null 2>&1 && echo "✓ mpv playback works" || echo "✗ mpv playback failed" + + if which ffplay >/dev/null; then + echo "Playing with ffplay..." + ffplay -nodisp -autoexit -loglevel quiet /tmp/test_tone.mp3 >/dev/null 2>&1 && echo "✓ ffplay playback works" || echo "✗ ffplay playback failed" + fi + + if which aplay >/dev/null; then + # Convert to wav for aplay + sox /tmp/test_tone.mp3 /tmp/test_tone.wav + echo "Playing with aplay..." + aplay /tmp/test_tone.wav >/dev/null 2>&1 && echo "✓ aplay playback works" || echo "✗ aplay playback failed" + fi +else + echo "Sox not found, skipping audio playback tests" + echo "Install sox with: sudo apt-get install sox (Debian/Ubuntu)" +fi + +# Test API endpoint +echo "Testing TTS API endpoint..." +curl -s "http://0.0.0.0:8000/v1/audio/models" | grep model && echo "✓ API is responding" || echo "✗ API not responding" + +echo "Done!" diff --git a/src/chatgpt.rs b/src/chatgpt.rs index 6b1d53f..92e850a 100644 --- a/src/chatgpt.rs +++ b/src/chatgpt.rs @@ -20,6 +20,7 @@ pub struct ChatGPT { openai_api_key: String, model: String, url: String, + system_prompt: String, messages: Vec>, } @@ -45,6 +46,7 @@ You need to define one whether in the configuration file or as an environment va openai_api_key, model: config.model, url: config.url, + system_prompt: config.system_prompt, messages: Vec::new(), } } @@ -80,7 +82,7 @@ impl LLM for ChatGPT { ("role".to_string(), "system".to_string()), ( "content".to_string(), - "You are a helpful assistant.".to_string(), + self.system_prompt.clone(), ), ])), ]; @@ -131,7 +133,7 @@ impl LLM for ChatGPT { sender.send(Event::LLMEvent(LLMAnswer::Answer(msg.to_string())))?; } - sleep(Duration::from_millis(100)).await; + sleep(Duration::from_millis(1)).await; } } } diff --git a/src/config.rs b/src/config.rs index 00cb77e..bd5bff0 100644 --- a/src/config.rs +++ b/src/config.rs @@ -19,6 +19,9 @@ pub struct Config { pub llamacpp: Option, pub ollama: Option, + + #[serde(default)] + pub tts: TTSConfig, } pub fn default_llm_backend() -> LLMBackend { @@ -35,6 +38,9 @@ pub struct ChatGPTConfig { #[serde(default = "ChatGPTConfig::default_url")] pub url: String, + + #[serde(default = "ChatGPTConfig::default_system_prompt")] + pub system_prompt: String, } impl Default for ChatGPTConfig { @@ -43,6 +49,7 @@ impl Default for ChatGPTConfig { openai_api_key: None, model: Self::default_model(), url: Self::default_url(), + system_prompt: Self::default_system_prompt(), } } } @@ -55,6 +62,10 @@ impl ChatGPTConfig { pub fn default_url() -> String { String::from("https://api.openai.com/v1/chat/completions") } + + pub fn default_system_prompt() -> String { + String::from("You are a helpful assistant.") + } } // LLamacpp @@ -73,6 +84,31 @@ pub struct OllamaConfig { pub model: String, } +// TTS +#[derive(Deserialize, Debug, Clone)] +pub struct TTSConfig { + #[serde(default = "TTSConfig::default_url")] + pub url: String, + + #[serde(default)] + pub default_voice: Option, +} + +impl Default for TTSConfig { + fn default() -> Self { + Self { + url: Self::default_url(), + default_voice: None, + } + } +} + +impl TTSConfig { + pub fn default_url() -> String { + String::from("http://0.0.0.0:8000/v1/audio/speech") + } +} + #[derive(Deserialize, Debug)] pub struct KeyBindings { #[serde(default = "KeyBindings::default_show_help")] @@ -86,6 +122,9 @@ pub struct KeyBindings { #[serde(default = "KeyBindings::default_stop_stream")] pub stop_stream: char, + + #[serde(default = "KeyBindings::default_load_voice")] + pub load_voice: char, } impl Default for KeyBindings { @@ -95,6 +134,7 @@ impl Default for KeyBindings { show_history: 'h', new_chat: 'n', stop_stream: 't', + load_voice: 'v', } } } @@ -115,6 +155,10 @@ impl KeyBindings { fn default_stop_stream() -> char { 't' } + + fn default_load_voice() -> char { + 'v' + } } impl Config { @@ -129,7 +173,7 @@ impl Config { }; let config = std::fs::read_to_string(conf_path).unwrap_or_default(); - let app_config: Config = toml::from_str(&config).unwrap(); + let mut app_config: Config = toml::from_str(&config).unwrap(); if app_config.llm == LLMBackend::LLamacpp && app_config.llamacpp.is_none() { eprintln!("Config for LLamacpp is not provided"); @@ -140,6 +184,20 @@ impl Config { eprintln!("Config for Ollama is not provided"); std::process::exit(1) } + + // Try to load saved default voice from file if one exists + let voice_file = dirs::config_dir() + .unwrap() + .join("tenere") + .join("default_voice.txt"); + + if voice_file.exists() { + if let Ok(voice_id) = std::fs::read_to_string(&voice_file) { + if !voice_id.trim().is_empty() { + app_config.tts.default_voice = Some(voice_id.trim().to_string()); + } + } + } app_config } diff --git a/src/event.rs b/src/event.rs index de62f09..4b53df3 100644 --- a/src/event.rs +++ b/src/event.rs @@ -15,6 +15,14 @@ pub enum Event { Resize(u16, u16), LLMEvent(LLMAnswer), Notification(Notification), + TTSEvent(TTSEvent), +} + +#[derive(Debug, Clone)] +pub enum TTSEvent { + PlayText { text: String, voice: Option }, + Complete, + Error(String), } #[allow(dead_code)] diff --git a/src/handler.rs b/src/handler.rs index 91c864c..3fa77be 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,5 +1,7 @@ use crate::llm::{LLMAnswer, LLMRole}; use crate::{chat::Chat, prompt::Mode}; +use crate::event::TTSEvent; +use crate::config::{TTSConfig, Config}; // Add Config import use crate::{ app::{App, AppResult, FocusedBlock}, @@ -16,6 +18,10 @@ use tokio::sync::Mutex; use tokio::sync::mpsc::UnboundedSender; +use crate::tts; +use tokio::fs; +use crate::notification::{Notification, NotificationLevel}; + pub async fn handle_key_events( key_event: KeyEvent, app: &mut App<'_>, @@ -38,6 +44,46 @@ pub async fn handle_key_events( .store(true, std::sync::atomic::Ordering::Relaxed); } + // Read the current response with TTS + KeyCode::Char('l') if key_event.modifiers.contains(KeyModifiers::CONTROL) => { + // Play the current answer with TTS + if !app.chat.answer.plain_answer.is_empty() { + sender.send(Event::TTSEvent(TTSEvent::PlayText { + text: app.chat.answer.plain_answer.clone(), + voice: None, + }))?; + } + } + + // Load voice for TTS + KeyCode::Char(c) if c == app.config.key_bindings.load_voice && + key_event.modifiers.contains(KeyModifiers::CONTROL) => { + // Spawn an async task to handle voice loading + let sender_clone = sender.clone(); + // Pass the actual app config here + let config_clone = Arc::clone(&app.config); + tokio::spawn(async move { + match load_voice_file(sender_clone.clone(), config_clone).await { + Ok(voice_id) => { + sender_clone.send(Event::Notification( + Notification::new( + format!("Voice loaded successfully: {}", voice_id), + NotificationLevel::Info + ) + )).unwrap_or_default(); + }, + Err(e) => { + sender_clone.send(Event::Notification( + Notification::new( + format!("Error loading voice: {}", e), + NotificationLevel::Error + ) + )).unwrap_or_default(); + } + } + }); + }, + // scroll down KeyCode::Char('j') | KeyCode::Down => match app.focused_block { FocusedBlock::History => { @@ -250,3 +296,235 @@ pub async fn handle_key_events( Ok(()) } + +/// Load a voice file from the configured directory and update the config +/// Cycles through available voices each time it's called +async fn load_voice_file( + sender: UnboundedSender, + config: Arc +) -> Result> { + // Get the voice directory + let voice_dir = tts::get_voice_dir()?; + + // Read all files in the directory + let mut entries = fs::read_dir(&voice_dir).await?; + let mut voice_files = Vec::new(); + + // Collect all audio files + while let Some(entry) = entries.next_entry().await? { + let path = entry.path(); + if path.is_file() { + // Only include files with audio extensions + if let Some(ext) = path.extension() { + let ext_str = ext.to_string_lossy().to_lowercase(); + if ["mp3", "wav", "ogg", "m4a", "flac"].contains(&ext_str.as_str()) { + voice_files.push(path); + } + } + } + } + + // If there are no voice files, return an error + if voice_files.is_empty() { + return Err(format!("No voice files found in {:?}. Place audio files in this directory.", voice_dir).into()); + } + + // Sort the files to ensure consistent order + voice_files.sort(); + + // Get the last used voice file index + let last_index_file = dirs::config_dir().unwrap().join("tenere").join("last_voice_index"); + let last_index = if last_index_file.exists() { + match tokio::fs::read_to_string(&last_index_file).await { + Ok(content) => content.trim().parse::().unwrap_or(0), + Err(_) => 0 + } + } else { + 0 + }; + + // Calculate the next index (cycling through the list) + let next_index = (last_index + 1) % voice_files.len(); + + // Save the next index for future calls + if let Some(parent) = last_index_file.parent() { + if !parent.exists() { + tokio::fs::create_dir_all(parent).await?; + } + } + tokio::fs::write(&last_index_file, next_index.to_string()).await?; + + // Get the selected voice file + let voice_path = &voice_files[next_index]; + let file_name = voice_path.file_name().unwrap().to_string_lossy().to_string(); + + // Create a more reliable cache key using file name and file size + let file_metadata = tokio::fs::metadata(voice_path).await?; + let file_size = file_metadata.len(); + let cache_key = format!("{}_size_{}", file_name, file_size); + + // Debug the voice file selection + // eprintln!("Selected voice file: {} (size: {} bytes)", file_name, file_size); + + // Check if we have a cached voice ID for this file to avoid re-uploading + let cache_file = dirs::config_dir().unwrap().join("tenere").join("voice_cache.json"); + let mut voice_id = None; + + // Try to get the voice ID from cache first + if cache_file.exists() { + // eprintln!("Voice cache file exists at: {:?}", cache_file); + + match tokio::fs::read_to_string(&cache_file).await { + Ok(content) => { + // eprintln!("Read cache content: {} bytes", content.len()); + // Parse as JSON map directly - more robust error handling + match serde_json::from_str::>(&content) { + Ok(cache_map) => { + // First try with the cache_key + if let Some(id) = cache_map.get(&cache_key).and_then(|v| v.as_str()) { + voice_id = Some(id.to_string()); + // eprintln!("Found voice ID in cache with key {}: {}", cache_key, id); + } + // Fallback to just the filename + else if let Some(id) = cache_map.get(&file_name).and_then(|v| v.as_str()) { + voice_id = Some(id.to_string()); + // eprintln!("Found voice ID in cache with filename {}: {}", file_name, id); + } else { + // eprintln!("No cache entry found for {} or {}", cache_key, file_name); + } + }, + Err(e) => { + sender.send(Event::Notification( + Notification::new( + format!("Failed to parse voice cache: {}", e), + NotificationLevel::Error + ) + ))?; + // eprintln!("Failed to parse voice cache: {}", e); + } + } + }, + Err(e) => { + sender.send(Event::Notification( + Notification::new( + format!("Failed to read voice cache file: {}", e), + NotificationLevel::Error + ) + ))?; + // eprintln!("Failed to read voice cache file: {}", e); + } + } + } else { + // eprintln!("Voice cache file doesn't exist yet at: {:?}", cache_file); + } + + // If not found in cache, upload the file + let voice_id = if let Some(id) = voice_id { + // Voice found in cache, notify the user + sender.send(Event::Notification( + Notification::new( + format!("Using voice: {} ({}/{})", + file_name, next_index + 1, voice_files.len()), + NotificationLevel::Info + ) + ))?; + id + } else { + // Voice not found in cache, upload it + // eprintln!("No cached voice found, uploading file: {}", file_name); + + // Upload the voice file and get the voice ID + let id = tts::upload_voice_file(voice_path, &config.tts).await?; + // eprintln!("Voice uploaded successfully with ID: {}", id); + + // Create the cache map + let mut cache_map = if cache_file.exists() { + match tokio::fs::read_to_string(&cache_file).await { + Ok(content) => match serde_json::from_str::>(&content) { + Ok(map) => map, + Err(_) => { + // If parsing fails, create a fresh map + // eprintln!("Cache file exists but couldn't be parsed, creating new one"); + serde_json::Map::new() + } + }, + Err(_) => serde_json::Map::new() + } + } else { + serde_json::Map::new() + }; + + // Add both the filename and the cache_key entries + cache_map.insert(file_name.clone(), serde_json::Value::String(id.clone())); + cache_map.insert(cache_key.clone(), serde_json::Value::String(id.clone())); + + let cache_content = serde_json::to_string_pretty(&cache_map)?; + + // Make sure the directory exists + let parent = cache_file.parent().unwrap(); + if !parent.exists() { + tokio::fs::create_dir_all(parent).await?; + } + + // Write the updated cache + match tokio::fs::write(&cache_file, &cache_content).await { + Ok(_) => sender.send(Event::Notification( + Notification::new( + "Voice cache updated successfully".to_string(), + NotificationLevel::Info + ) + )).unwrap_or_default(), + Err(e) => sender.send(Event::Notification( + Notification::new( + format!("Failed to write voice cache file: {}", e), + NotificationLevel::Error + ) + )).unwrap_or_default(), + } + + // Send notification that we're uploading a new voice + sender.send(Event::Notification( + Notification::new( + format!("Uploaded new voice: {} ({}/{})", + file_name, next_index + 1, voice_files.len()), + NotificationLevel::Info + ) + ))?; + + id + }; + + // Update the config file + let config_dir = dirs::config_dir().unwrap().join("tenere"); + let config_path = config_dir.join("config.toml"); + + // Read the existing config + let config_content = match tokio::fs::read_to_string(&config_path).await { + Ok(content) => content, + Err(_) => String::new() + }; + + // Parse it as a document to preserve formatting and comments + let mut doc = match config_content.parse::() { + Ok(doc) => doc, + Err(_) => toml_edit::Document::new() + }; + + // Update the voice in the config file + if !doc.as_table().contains_key("tts") { + doc["tts"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + doc["tts"]["default_voice"] = toml_edit::value(voice_id.clone()); + + // Write the config back + tokio::fs::write(&config_path, doc.to_string()).await?; + + // Update the in-memory config too + let tts_config_ptr = &config.tts as *const TTSConfig as *mut TTSConfig; + unsafe { + (*tts_config_ptr).default_voice = Some(voice_id.clone()); + } + + // Return the voice ID + Ok(voice_id) +} diff --git a/src/lib.rs b/src/lib.rs index 24035c8..0c8d6b9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,3 +31,5 @@ pub mod chat; pub mod llamacpp; pub mod ollama; + +pub mod tts; \ No newline at end of file diff --git a/src/llm.rs b/src/llm.rs index 5da3750..73a13fc 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -57,4 +57,5 @@ impl LLMModel { LLMBackend::Ollama => Box::new(Ollama::new(config.ollama.clone().unwrap())), } } -} + +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 96fe544..8b8cb26 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,12 +2,13 @@ use ratatui::backend::CrosstermBackend; use ratatui::Terminal; use std::{env, io, path::PathBuf}; use tenere::app::{App, AppResult}; -use tenere::config::Config; -use tenere::event::{Event, EventHandler}; +use tenere::config::{Config, TTSConfig}; +use tenere::event::{Event, EventHandler, TTSEvent}; use tenere::formatter::Formatter; use tenere::handler::handle_key_events; -use tenere::llm::{LLMAnswer, LLMRole}; +use tenere::llm::{LLMAnswer, LLMRole, LLM}; // Add LLM import use tenere::tui::Tui; +use tenere::tts; use tenere::llm::LLMModel; @@ -16,6 +17,8 @@ use tokio::sync::Mutex; use clap::{crate_description, crate_version, Arg, Command}; +use ratatui::backend::Backend; // Add this import + #[tokio::main] async fn main() -> AppResult<()> { let matches = Command::new("tenere") @@ -55,41 +58,103 @@ async fn main() -> AppResult<()> { // load potential history data from archive files app.history.load_history(tui.events.sender.clone()); + // Make sure to clean up TTS processes on exit + let result = run_app(&mut app, llm, &mut tui, &formatter, &config).await; + + // Clean up TTS processes before exiting + tts::kill_all_tts_processes(); + + tui.exit()?; + result +} + +async fn run_app( + app: &mut App<'_>, + llm: Arc>>, + tui: &mut Tui, + formatter: &Formatter<'_>, + config: &Arc +) -> AppResult<()> { while app.running { - tui.draw(&mut app)?; + tui.draw(app)?; match tui.events.next().await? { Event::Tick => app.tick(), Event::Key(key_event) => { - handle_key_events(key_event, &mut app, llm.clone(), tui.events.sender.clone()) + handle_key_events(key_event, app, llm.clone(), tui.events.sender.clone()) .await?; } Event::Mouse(_) => {} Event::Resize(_, _) => {} Event::LLMEvent(LLMAnswer::Answer(answer)) => { app.chat - .handle_answer(LLMAnswer::Answer(answer), &formatter); + .handle_answer(LLMAnswer::Answer(answer.clone()), formatter); + + // We don't want to trigger TTS for every tiny chunk + // Only send longer message portions to avoid choppy audio + if answer.len() > 80 && answer.contains('.') { + tui.events.sender.send(Event::TTSEvent(TTSEvent::PlayText { + text: answer, + voice: None + }))?; + } } Event::LLMEvent(LLMAnswer::EndAnswer) => { { let mut llm = llm.lock().await; llm.append_chat_msg(app.chat.answer.plain_answer.clone(), LLMRole::ASSISTANT); + + // Play the full response with TTS when it completes, + // using the default voice from config if set. + let final_answer = app.chat.answer.plain_answer.clone(); + if !final_answer.is_empty() { + tui.events.sender.send(Event::TTSEvent(TTSEvent::PlayText { + text: final_answer, + voice: config.tts.default_voice.clone(), // Optional default voice + }))?; + } } - - app.chat.handle_answer(LLMAnswer::EndAnswer, &formatter); + app.chat.handle_answer(LLMAnswer::EndAnswer, formatter); app.terminate_response_signal .store(false, std::sync::atomic::Ordering::Relaxed); } Event::LLMEvent(LLMAnswer::StartAnswer) => { app.spinner.active = false; - app.chat.handle_answer(LLMAnswer::StartAnswer, &formatter); + app.chat.handle_answer(LLMAnswer::StartAnswer, formatter); } Event::Notification(notification) => { app.notifications.push(notification); } + Event::TTSEvent(tts_event) => { + handle_tts_event(tts_event, &config.tts).await; + } } } - tui.exit()?; Ok(()) } + +async fn handle_tts_event(event: TTSEvent, tts_config: &TTSConfig) { + match event { + TTSEvent::PlayText { text, voice: _ } => { + // Clone what we need to move into the background task + let tts_config = tts_config.clone(); + let text = text.clone(); + + // Spawn a background task for TTS playback to avoid blocking the UI + tokio::spawn(async move { + if let Err(e) = tts::play_tts(&text, &tts_config).await { + eprintln!("TTS error: {}", e); + } + }); + + // Return immediately so the main application can continue handling input + }, + TTSEvent::Complete => { + // TTS playback completed + }, + TTSEvent::Error(err) => { + eprintln!("TTS error: {}", err); + } + } +} diff --git a/src/tts.rs b/src/tts.rs new file mode 100644 index 0000000..8f24a85 --- /dev/null +++ b/src/tts.rs @@ -0,0 +1,413 @@ +use std::error::Error; +use std::process::Stdio; +use tokio::io::AsyncWriteExt; +use serde::Serialize; +use futures::StreamExt; +use reqwest::Client; +use reqwest::header; +use tokio::process::Command as TokioCommand; +use crate::config::TTSConfig; +use std::sync::{Arc, Mutex, Once}; +use lazy_static::lazy_static; +use std::collections::HashSet; + +// Debug helper macro - you can remove this after debugging +macro_rules! debug { + ($($arg:tt)*) => { + // Log to a file for debugging + if let Ok(mut file) = std::fs::OpenOptions::new().create(true).append(true).open("/tmp/tenere_tts_debug.log") { + use std::io::Write; + let _ = writeln!(&mut file, "[{}] {}", + chrono::Local::now().format("%H:%M:%S%.3f"), + format!($($arg)*)); + } + }; +} + +/// Request structure for the new TTS API +#[derive(Debug, Serialize)] +struct TTSRequest { + model: String, + input: String, + #[serde(skip_serializing_if = "Option::is_none")] + voice: Option, + speed: f32, + language: String, + #[serde(skip_serializing_if = "Option::is_none")] + emotion: Option, + response_format: String, +} + +/// Upload a voice file to be used as a custom TTS voice +pub async fn upload_voice_file(file_path: &std::path::Path, tts_config: &TTSConfig) -> Result> { + debug!("Uploading voice file: {:?}", file_path); + + // Extract the filename without extension to use as voice name + let file_name = file_path.file_stem() + .and_then(|os_str| os_str.to_str()) + .unwrap_or("default_voice"); + + // Check if file exists + if !file_path.exists() { + return Err(format!("Voice file not found: {:?}", file_path).into()); + } + + // Read the file content + let file_content = tokio::fs::read(file_path).await?; + debug!("Read voice file with {} bytes", file_content.len()); + + // Construct the voice API endpoint URL from the base TTS URL + let base_url = tts_config.url.trim_end_matches("/speech").trim_end_matches("/"); + let voice_url = format!("{}/voice", base_url); + debug!("Using voice API endpoint: {}", voice_url); + + // Since we don't have multipart feature enabled, we'll use curl command-line instead + let file_path_str = file_path.to_string_lossy(); + let output = tokio::process::Command::new("curl") + .args([ + "-X", "POST", + "-F", &format!("file=@{}", file_path_str), + "-F", &format!("name={}", file_name), + &voice_url + ]) + .output() + .await?; + + if !output.status.success() { + let error = String::from_utf8_lossy(&output.stderr); + debug!("Voice upload failed: {}", error); + return Err(format!("Voice upload failed: {}", error).into()); + } + + // Parse the JSON response to get the voice ID + let response_json = String::from_utf8_lossy(&output.stdout); + let response: serde_json::Value = serde_json::from_str(&response_json)?; + + let voice_id = response["voice_id"].as_str() + .ok_or("Invalid response: missing voice_id field")? + .to_string(); + + debug!("Successfully uploaded voice with ID: {}", voice_id); + Ok(voice_id) +} + +/// Play text through TTS service with pure streaming (no file storage) +pub async fn play_tts(text: &str, tts_config: &TTSConfig) -> Result<(), Box> { + debug!("TTS request for text: {}", text); + + // Add a terminal bell to indicate TTS is starting (optional) + print!("\x07"); // Bell character + + // Skip empty or whitespace-only text + let text = text.trim(); + if text.is_empty() { + debug!("Skipping TTS for empty text"); + return Ok(()); + } + + // Build the request with the API parameters + let request = TTSRequest { + model: "Zyphra/Zonos-v0.1-transformer".to_string(), + input: text.to_string(), + voice: tts_config.default_voice.clone(), // Use default voice if configured + speed: 1.0, + language: "en-us".to_string(), + emotion: None, + response_format: "mp3".to_string(), + }; + + debug!("Sending request to TTS API at: {}", tts_config.url); + + // Send request to TTS service using the configured URL + let client = Client::new(); + let response = client.post(&tts_config.url) + .json(&request) + .send() + .await?; + + let status = response.status(); + debug!("Got response with status: {}", status); + + if !status.is_success() { + let error_text = response.text().await?; + debug!("Error response: {}", error_text); + return Err(format!("TTS request failed with status: {}, body: {}", status, error_text).into()); + } + + // Get the content type to pass to player + let content_type = response.headers() + .get(header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .unwrap_or("audio/mp3") + .to_string(); + + debug!("Content type: {}", content_type); + + // Stream the audio directly to the player + stream_audio(response, &content_type).await +} + +/// Stream audio data directly to a player +async fn stream_audio( + response: reqwest::Response, + content_type: &str +) -> Result<(), Box> { + debug!("Starting audio streaming"); + + // Set up a suitable player based on what's available + debug!("Setting up audio player"); + let (mut player_child, mut player_stdin) = match setup_streaming_player(content_type) { + Ok(player) => player, + Err(e) => { + debug!("Player setup failed: {}", e); + return Err(e); + } + }; + + // Process chunks as they arrive + let mut stream = stream_helpers::get_stream(response); + let mut total_bytes = 0; + let mut chunk_count = 0; + + debug!("Starting to receive audio chunks"); + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + chunk_count += 1; + total_bytes += chunk.len(); + debug!("Received chunk #{} - {} bytes", chunk_count, chunk.len()); + + // Write directly to player's stdin + if let Err(e) = player_stdin.write_all(&chunk).await { + debug!("Error writing to player: {}", e); + return Err(e.into()); + } + }, + Err(e) => { + debug!("Error in stream: {}", e); + return Err(e.into()); + } + } + } + + debug!("All chunks received. Total: {} chunks, {} bytes", chunk_count, total_bytes); + + // Close stdin to signal end of input + drop(player_stdin); + debug!("Closed stdin, waiting for player to finish"); + + // Wait for player to finish + let pid = player_child.id(); // Save PID before waiting + let status = player_child.wait().await?; + + // Unregister the process since it's done + if let Some(pid_val) = pid { + unregister_tts_process(pid_val); + } + + if !status.success() { + let code = status.code().unwrap_or(-1); + debug!("Player exited with error code: {}", code); + return Err(format!("Audio player exited with code {}", code).into()); + } + + debug!("Audio playback completed successfully"); + Ok(()) +} + +// Helper function to get a stream from response +mod stream_helpers { + use futures::Stream; + use futures::StreamExt; + use std::pin::Pin; + + pub fn get_stream( + response: reqwest::Response + ) -> Pin, reqwest::Error>> + Send>> { + Box::pin(response.bytes_stream().map(|result| { + result.map(|bytes| bytes.to_vec()) + })) + } +} + +/// Set up a streaming audio player based on what's available +fn setup_streaming_player(content_type: &str) -> Result<(tokio::process::Child, tokio::process::ChildStdin), Box> { + // Make sure cleanup is registered before creating any new processes + register_cleanup(); + + // Try to find which players are available on the system + let mpv_available = std::process::Command::new("mpv").arg("--version").output().is_ok(); + let ffplay_available = std::process::Command::new("ffplay").arg("-version").output().is_ok(); + let aplay_available = std::process::Command::new("aplay").arg("--version").output().is_ok(); + + debug!("Available players: mpv={}, ffplay={}, aplay={}", + mpv_available, ffplay_available, aplay_available); + + // Try mpv first (most versatile) + if mpv_available { + debug!("Trying to use mpv for playback"); + let mut command = TokioCommand::new("mpv") + .args(["-", "--no-cache", "--no-terminal", "--audio-buffer=0.1"]) + .stdin(Stdio::piped()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn()?; + + // Register the process for cleanup + if let Some(pid) = command.id() { + register_tts_process(pid); + } + + let stdin = command.stdin.take() + .ok_or_else(|| "Failed to open mpv stdin".to_string())?; + debug!("Successfully started mpv"); + return Ok((command, stdin)); + } + + // Try ffplay as second option + if ffplay_available { + debug!("Trying to use ffplay for playback"); + let mut command = TokioCommand::new("ffplay") + .args(["-i", "pipe:0", "-autoexit", "-nodisp", "-hide_banner", "-loglevel", "quiet"]) + .stdin(Stdio::piped()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn()?; + + // Register the process for cleanup + if let Some(pid) = command.id() { + register_tts_process(pid); + } + + let stdin = command.stdin.take() + .ok_or_else(|| "Failed to open ffplay stdin".to_string())?; + debug!("Successfully started ffplay"); + return Ok((command, stdin)); + } + + // For aplay (Linux) - only works with WAV + if aplay_available && content_type.contains("wav") { + debug!("Trying to use aplay for playback"); + let mut command = TokioCommand::new("aplay") + .stdin(Stdio::piped()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn()?; + + // Register the process for cleanup + if let Some(pid) = command.id() { + register_tts_process(pid); + } + + let stdin = command.stdin.take() + .ok_or_else(|| "Failed to open aplay stdin".to_string())?; + debug!("Successfully started aplay"); + return Ok((command, stdin)); + } + + debug!("No suitable player found!"); + Err("No suitable streaming audio player found. Please install mpv, ffplay, or aplay.".into()) +} + +/// Helper function to get the default voice file directory +pub fn get_voice_dir() -> Result> { + let voice_dir = dirs::config_dir() + .ok_or_else(|| "Failed to find config directory")? + .join("tenere") + .join("audio"); + + // Create directory if it doesn't exist + if !voice_dir.exists() { + std::fs::create_dir_all(&voice_dir)?; + } + + Ok(voice_dir) +} + +/// Load a voice from a file in the config directory and set as default +pub async fn load_voice_from_file(file_name: &str, tts_config: &mut TTSConfig) -> Result> { + let voice_dir = get_voice_dir()?; + let file_path = voice_dir.join(file_name); + + debug!("Loading voice from file: {:?}", file_path); + + // Upload the voice file + let voice_id = upload_voice_file(&file_path, tts_config).await?; + + // Store the voice ID as the default + tts_config.default_voice = Some(voice_id.clone()); + + Ok(voice_id) +} + +// Add a global registry for tracking spawned processes +lazy_static! { + static ref TTS_PROCESSES: Arc>> = Arc::new(Mutex::new(HashSet::new())); + static ref CLEANUP_REGISTERED: Once = Once::new(); +} + +// Register cleanup handler for program termination +fn register_cleanup() { + CLEANUP_REGISTERED.call_once(|| { + let _processes = TTS_PROCESSES.clone(); + + // Register normal exit cleanup + std::env::set_var("TENERE_CLEANUP_REGISTERED", "true"); + + // Register panic cleanup + let default_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new(move |panic_info| { + kill_all_tts_processes(); + default_hook(panic_info); + })); + + // Register CTRL+C handler + ctrlc::set_handler(move || { + kill_all_tts_processes(); + std::process::exit(0); + }).expect("Error setting Ctrl-C handler"); + }); +} + +// Kill all registered TTS processes +pub fn kill_all_tts_processes() { + debug!("Cleaning up TTS processes"); + if let Ok(mut processes) = TTS_PROCESSES.lock() { + for pid in processes.iter() { + debug!("Killing TTS process with PID: {}", pid); + // Cross-platform way to kill a process by PID + #[cfg(target_os = "windows")] + { + use std::process::Command; + let _ = Command::new("taskkill") + .args(&["/F", "/PID", &pid.to_string()]) + .output(); + } + + #[cfg(not(target_os = "windows"))] + { + // On Unix systems we can use the kill system call + unsafe { + libc::kill(*pid as i32, libc::SIGTERM); + } + } + } + processes.clear(); + } +} + +// Register a new TTS process +fn register_tts_process(pid: u32) { + if let Ok(mut processes) = TTS_PROCESSES.lock() { + processes.insert(pid); + debug!("Registered TTS process: {}", pid); + } +} + +// Unregister a TTS process when it completes +fn unregister_tts_process(pid: u32) { + if let Ok(mut processes) = TTS_PROCESSES.lock() { + processes.remove(&pid); + debug!("Unregistered TTS process: {}", pid); + } +}