diff --git a/.gitignore b/.gitignore index d0709f973..c195e9f4f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # Byte-compiled / optimized / DLL files __pycache__/ .idea/ +.vscode/ *.py[cod] *$py.class diff --git a/main/xiaozhi-server/core/connection.py b/main/xiaozhi-server/core/connection.py index c9a52b120..4d33f5bc2 100644 --- a/main/xiaozhi-server/core/connection.py +++ b/main/xiaozhi-server/core/connection.py @@ -2,11 +2,10 @@ import json import uuid import time -import queue +import copy import asyncio import traceback -import threading import websockets from typing import Dict, Any from plugins_func.loadplugins import auto_import_modules @@ -14,7 +13,6 @@ from core.utils.dialogue import Message, Dialogue from core.handle.textHandle import handleTextMessage from core.utils.util import get_string_no_punctuation_or_emoji, extract_json_from_string, get_ip_info -from concurrent.futures import ThreadPoolExecutor, TimeoutError from core.handle.sendAudioHandle import sendAudioMessage from core.handle.receiveAudioHandle import handleAudioMessage from core.handle.functionHandler import FunctionHandler @@ -50,12 +48,12 @@ def __init__(self, config: Dict[str, Any], _vad, _asr, _llm, _tts, _memory, _int self.client_abort = False self.client_listen_mode = "auto" - # 线程任务相关 - self.loop = asyncio.get_event_loop() - self.stop_event = threading.Event() - self.tts_queue = queue.Queue() - self.audio_play_queue = queue.Queue() - self.executor = ThreadPoolExecutor(max_workers=10) + # 异步任务相关 + self.loop = asyncio.get_running_loop() + self.stop_event = asyncio.Event() + self.tts_queue = asyncio.Queue() + self.audio_play_queue = asyncio.Queue() + self.background_tasks = set() # 依赖的组件 self.vad = _vad @@ -117,7 +115,7 @@ async def handle_connection(self, ws): self.websocket = ws self.session_id = str(uuid.uuid4()) - self.welcome_msg = self.config["xiaozhi"] + self.welcome_msg = copy.deepcopy(self.config["xiaozhi"]) self.welcome_msg["session_id"] = self.session_id await self.websocket.send(json.dumps(self.welcome_msg)) # Load private configuration if device_id is provided @@ -148,14 +146,17 @@ async def handle_connection(self, ws): raise # 异步初始化 - self.executor.submit(self._initialize_components) - # tts 消化线程 - tts_priority = threading.Thread(target=self._tts_priority_thread, daemon=True) - tts_priority.start() + await self.loop.run_in_executor(None, self._initialize_components) - # 音频播放 消化线程 - audio_play_priority = threading.Thread(target=self._audio_play_priority_thread, daemon=True) - audio_play_priority.start() + # 启动TTS任务 + tts_task = asyncio.create_task(self._tts_priority_task()) + self.background_tasks.add(tts_task) + tts_task.add_done_callback(self.background_tasks.discard) + + # 启动音频播放任务 + audio_play_task = asyncio.create_task(self._audio_play_priority_task()) + self.background_tasks.add(audio_play_task) + audio_play_task.add_done_callback(self.background_tasks.discard) try: async for message in self.websocket: @@ -220,8 +221,7 @@ async def _check_and_broadcast_auth_code(self): # 发送验证码语音提示 text = f"请在后台输入验证码:{' '.join(auth_code)}" self.recode_first_last_text(text) - future = self.executor.submit(self.speak_and_play, text) - self.tts_queue.put(future) + await self.tts_queue.put((text, 0)) return False return True @@ -232,11 +232,10 @@ def isNeedAuth(self): return False return not self.is_device_verified - def chat(self, query): + async def chat(self, query): if self.isNeedAuth(): self.llm_finish_task = True - future = asyncio.run_coroutine_threadsafe(self._check_and_broadcast_auth_code(), self.loop) - future.result() + await self._check_and_broadcast_auth_code() return True self.dialogue.put(Message(role="user", content=query)) @@ -246,8 +245,7 @@ def chat(self, query): try: start_time = time.time() # 使用带记忆的对话 - future = asyncio.run_coroutine_threadsafe(self.memory.query_memory(query), self.loop) - memory_str = future.result() + memory_str = await self.memory.query_memory(query) self.logger.bind(tag=TAG).debug(f"记忆内容: {memory_str}") llm_responses = self.llm.response( @@ -290,8 +288,7 @@ def chat(self, query): # segment_text = " " text_index += 1 self.recode_first_last_text(segment_text, text_index) - future = self.executor.submit(self.speak_and_play, segment_text, text_index) - self.tts_queue.put(future) + await self.tts_queue.put((segment_text, text_index)) processed_chars += len(segment_text_raw) # 更新已处理字符位置 # 处理最后剩余的文本 @@ -302,21 +299,24 @@ def chat(self, query): if segment_text: text_index += 1 self.recode_first_last_text(segment_text, text_index) - future = self.executor.submit(self.speak_and_play, segment_text, text_index) - self.tts_queue.put(future) + await self.tts_queue.put((segment_text, text_index)) self.llm_finish_task = True self.dialogue.put(Message(role="assistant", content="".join(response_message))) self.logger.bind(tag=TAG).debug(json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False)) return True - def chat_with_function_calling(self, query, tool_call=False): + def create_chat_task(self, query): + task = asyncio.create_task(self.chat(query)) + self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) + + async def chat_with_function_calling(self, query, tool_call=False): self.logger.bind(tag=TAG).debug(f"Chat with function calling start: {query}") """Chat with function calling for intent detection using streaming""" if self.isNeedAuth(): self.llm_finish_task = True - future = asyncio.run_coroutine_threadsafe(self._check_and_broadcast_auth_code(), self.loop) - future.result() + await self._check_and_broadcast_auth_code() return True if not tool_call: @@ -332,8 +332,7 @@ def chat_with_function_calling(self, query, tool_call=False): start_time = time.time() # 使用带记忆的对话 - future = asyncio.run_coroutine_threadsafe(self.memory.query_memory(query), self.loop) - memory_str = future.result() + memory_str = await self.memory.query_memory(query) # self.logger.bind(tag=TAG).info(f"对话记录: {self.dialogue.get_llm_dialogue_with_memory(memory_str)}") @@ -403,8 +402,7 @@ def chat_with_function_calling(self, query, tool_call=False): if segment_text: text_index += 1 self.recode_first_last_text(segment_text, text_index) - future = self.executor.submit(self.speak_and_play, segment_text, text_index) - self.tts_queue.put(future) + await self.tts_queue.put((segment_text, text_index)) processed_chars += len(segment_text_raw) # 更新已处理字符位置 # 处理function call @@ -437,7 +435,7 @@ def chat_with_function_calling(self, query, tool_call=False): "arguments": function_arguments } result = self.func_handler.handle_llm_function_call(self, function_call_data) - self._handle_function_result(result, function_call_data, text_index + 1) + await self._handle_function_result(result, function_call_data, text_index + 1) # 处理最后剩余的文本 full_text = "".join(response_message) @@ -447,8 +445,7 @@ def chat_with_function_calling(self, query, tool_call=False): if segment_text: text_index += 1 self.recode_first_last_text(segment_text, text_index) - future = self.executor.submit(self.speak_and_play, segment_text, text_index) - self.tts_queue.put(future) + await self.tts_queue.put((segment_text, text_index)) # 存储对话内容 if len(response_message) > 0: @@ -459,12 +456,16 @@ def chat_with_function_calling(self, query, tool_call=False): return True - def _handle_function_result(self, result, function_call_data, text_index): + def create_chat_with_function_calling_task(self, query, tool_call=False): + task = asyncio.create_task(self.chat_with_function_calling(query, tool_call)) + self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) + + async def _handle_function_result(self, result, function_call_data, text_index): if result.action == Action.RESPONSE: # 直接回复前端 text = result.response self.recode_first_last_text(text, text_index) - future = self.executor.submit(self.speak_and_play, text, text_index) - self.tts_queue.put(future) + await self.tts_queue.put((text, text_index)) self.dialogue.put(Message(role="assistant", content=text)) elif result.action == Action.REQLLM: # 调用函数后再请求llm生成回复 @@ -481,33 +482,32 @@ def _handle_function_result(self, result, function_call_data, text_index): "index": 0}])) self.dialogue.put(Message(role="tool", tool_call_id=function_id, content=text)) - self.chat_with_function_calling(text, tool_call=True) + await self.chat_with_function_calling(text, tool_call=True) elif result.action == Action.NOTFOUND: text = result.result self.recode_first_last_text(text, text_index) - future = self.executor.submit(self.speak_and_play, text, text_index) - self.tts_queue.put(future) + await self.tts_queue.put((text, text_index)) self.dialogue.put(Message(role="assistant", content=text)) else: text = result.result self.recode_first_last_text(text, text_index) - future = self.executor.submit(self.speak_and_play, text, text_index) - self.tts_queue.put(future) + await self.tts_queue.put((text, text_index)) self.dialogue.put(Message(role="assistant", content=text)) - def _tts_priority_thread(self): + async def _tts_priority_task(self): while not self.stop_event.is_set(): text = None try: - future = self.tts_queue.get() - if future is None: + text, text_index = await self.tts_queue.get() + if text is None: continue - text = None - opus_datas, text_index, tts_file = [], 0, None + + opus_datas, tts_file = [], None try: self.logger.bind(tag=TAG).debug("正在处理TTS任务...") tts_timeout = self.config.get("tts_timeout", 10) - tts_file, text, text_index = future.result(timeout=tts_timeout) + tts_file, text, text_index = await asyncio.wait_for(self.speak_and_play(text, text_index), timeout=tts_timeout) + if text is None or len(text) <= 0: self.logger.bind(tag=TAG).error(f"TTS出错:{text_index}: tts text is empty") elif tts_file is None: @@ -518,40 +518,40 @@ def _tts_priority_thread(self): opus_datas, duration = self.tts.audio_to_opus_data(tts_file) else: self.logger.bind(tag=TAG).error(f"TTS出错:文件不存在{tts_file}") - except TimeoutError: + except asyncio.TimeoutError: self.logger.bind(tag=TAG).error("TTS超时") except Exception as e: self.logger.bind(tag=TAG).error(f"TTS出错: {e}") + if not self.client_abort: # 如果没有中途打断就发送语音 - self.audio_play_queue.put((opus_datas, text, text_index)) + await self.audio_play_queue.put((opus_datas, text, text_index)) + if self.tts.delete_audio_file and tts_file is not None and os.path.exists(tts_file): os.remove(tts_file) except Exception as e: self.logger.bind(tag=TAG).error(f"TTS任务处理错误: {e}") self.clearSpeakStatus() - asyncio.run_coroutine_threadsafe( - self.websocket.send(json.dumps({"type": "tts", "state": "stop", "session_id": self.session_id})), - self.loop - ) - self.logger.bind(tag=TAG).error(f"tts_priority priority_thread: {text} {e}") + await self.websocket.send(json.dumps({"type": "tts", "state": "stop", "session_id": self.session_id})) + self.logger.bind(tag=TAG).error(f"tts_priority task: {text} {e}") - def _audio_play_priority_thread(self): + async def _audio_play_priority_task(self): while not self.stop_event.is_set(): text = None try: - opus_datas, text, text_index = self.audio_play_queue.get() - future = asyncio.run_coroutine_threadsafe(sendAudioMessage(self, opus_datas, text, text_index), - self.loop) - future.result() + opus_datas, text, text_index = await self.audio_play_queue.get() + await sendAudioMessage(self, opus_datas, text, text_index) except Exception as e: - self.logger.bind(tag=TAG).error(f"audio_play_priority priority_thread: {text} {e}") + self.logger.bind(tag=TAG).error(f"audio_play_priority task: {text} {e}") - def speak_and_play(self, text, text_index=0): + async def speak_and_play(self, text, text_index=0): if text is None or len(text) <= 0: self.logger.bind(tag=TAG).info(f"无需tts转换,query为空,{text}") return None, text, text_index - tts_file = self.tts.to_tts(text) + + # 使用事件循环运行同步的TTS方法 + tts_file = await self.loop.run_in_executor(None, self.tts.to_tts, text) + if tts_file is None: self.logger.bind(tag=TAG).error(f"tts转换失败,{text}") return None, text, text_index @@ -575,7 +575,6 @@ async def close(self): # 清理其他资源 self.stop_event.set() - self.executor.shutdown(wait=False) if self.websocket: await self.websocket.close() self.logger.bind(tag=TAG).info("连接资源已释放") @@ -587,13 +586,18 @@ def reset_vad_states(self): self.client_voice_stop = False self.logger.bind(tag=TAG).debug("VAD states reset.") - def chat_and_close(self, text): + async def chat_and_close(self, text): """Chat with the user and then close the connection""" try: # Use the existing chat method - self.chat(text) + await self.chat(text) # After chat is complete, close the connection self.close_after_chat = True except Exception as e: self.logger.bind(tag=TAG).error(f"Chat and close error: {str(e)}") + + def create_chat_and_close_task(self, text): + task = asyncio.create_task(self.chat_and_close(text)) + self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) diff --git a/main/xiaozhi-server/core/handle/intentHandler.py b/main/xiaozhi-server/core/handle/intentHandler.py index d8440fc6e..7227a8744 100644 --- a/main/xiaozhi-server/core/handle/intentHandler.py +++ b/main/xiaozhi-server/core/handle/intentHandler.py @@ -73,7 +73,7 @@ async def process_intent_result(conn, intent, original_text): logger.bind(tag=TAG).info(f"识别到退出意图: {intent}") # 如果是明确的离别意图,发送告别语并关闭连接 await send_stt_message(conn, original_text) - conn.executor.submit(conn.chat_and_close, original_text) + conn.create_chat_and_close_task(original_text) return True # 处理播放音乐意图 @@ -112,4 +112,4 @@ def extract_text_in_brackets(s): if left_bracket_index != -1 and right_bracket_index != -1 and left_bracket_index < right_bracket_index: return s[left_bracket_index + 1:right_bracket_index] else: - return "" \ No newline at end of file + return "" diff --git a/main/xiaozhi-server/core/handle/receiveAudioHandle.py b/main/xiaozhi-server/core/handle/receiveAudioHandle.py index c5b4135ca..7a11a65e9 100644 --- a/main/xiaozhi-server/core/handle/receiveAudioHandle.py +++ b/main/xiaozhi-server/core/handle/receiveAudioHandle.py @@ -57,9 +57,9 @@ async def startToChat(conn, text): await send_stt_message(conn, text) if conn.use_function_call_mode: # 使用支持function calling的聊天方法 - conn.executor.submit(conn.chat_with_function_calling, text) + conn.create_chat_with_function_calling_task(text) else: - conn.executor.submit(conn.chat, text) + conn.create_chat_task(text) async def no_voice_close_connect(conn):