Skip to content

把 threading 相关代码改成 asyncio 的形式 #501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
.idea/
.vscode/
*.py[cod]
*$py.class

Expand Down
144 changes: 74 additions & 70 deletions main/xiaozhi-server/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@
import json
import uuid
import time
import queue
import copy
import asyncio
import traceback

import threading
import websockets
from typing import Dict, Any
from plugins_func.loadplugins import auto_import_modules
from config.logger import setup_logging
from core.utils.dialogue import Message, Dialogue
from core.handle.textHandle import handleTextMessage
from core.utils.util import get_string_no_punctuation_or_emoji, extract_json_from_string, get_ip_info
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from core.handle.sendAudioHandle import sendAudioMessage
from core.handle.receiveAudioHandle import handleAudioMessage
from core.handle.functionHandler import FunctionHandler
Expand Down Expand Up @@ -50,12 +48,12 @@ def __init__(self, config: Dict[str, Any], _vad, _asr, _llm, _tts, _memory, _int
self.client_abort = False
self.client_listen_mode = "auto"

# 线程任务相关
self.loop = asyncio.get_event_loop()
self.stop_event = threading.Event()
self.tts_queue = queue.Queue()
self.audio_play_queue = queue.Queue()
self.executor = ThreadPoolExecutor(max_workers=10)
# 异步任务相关
self.loop = asyncio.get_running_loop()
self.stop_event = asyncio.Event()
self.tts_queue = asyncio.Queue()
self.audio_play_queue = asyncio.Queue()
self.background_tasks = set()

# 依赖的组件
self.vad = _vad
Expand Down Expand Up @@ -117,7 +115,7 @@ async def handle_connection(self, ws):
self.websocket = ws
self.session_id = str(uuid.uuid4())

self.welcome_msg = self.config["xiaozhi"]
self.welcome_msg = copy.deepcopy(self.config["xiaozhi"])
self.welcome_msg["session_id"] = self.session_id
await self.websocket.send(json.dumps(self.welcome_msg))
# Load private configuration if device_id is provided
Expand Down Expand Up @@ -148,14 +146,17 @@ async def handle_connection(self, ws):
raise

# 异步初始化
self.executor.submit(self._initialize_components)
# tts 消化线程
tts_priority = threading.Thread(target=self._tts_priority_thread, daemon=True)
tts_priority.start()
await self.loop.run_in_executor(None, self._initialize_components)

# 音频播放 消化线程
audio_play_priority = threading.Thread(target=self._audio_play_priority_thread, daemon=True)
audio_play_priority.start()
# 启动TTS任务
tts_task = asyncio.create_task(self._tts_priority_task())
self.background_tasks.add(tts_task)
tts_task.add_done_callback(self.background_tasks.discard)

# 启动音频播放任务
audio_play_task = asyncio.create_task(self._audio_play_priority_task())
self.background_tasks.add(audio_play_task)
audio_play_task.add_done_callback(self.background_tasks.discard)

try:
async for message in self.websocket:
Expand Down Expand Up @@ -220,8 +221,7 @@ async def _check_and_broadcast_auth_code(self):
# 发送验证码语音提示
text = f"请在后台输入验证码:{' '.join(auth_code)}"
self.recode_first_last_text(text)
future = self.executor.submit(self.speak_and_play, text)
self.tts_queue.put(future)
await self.tts_queue.put((text, 0))
return False
return True

Expand All @@ -232,11 +232,10 @@ def isNeedAuth(self):
return False
return not self.is_device_verified

def chat(self, query):
async def chat(self, query):
if self.isNeedAuth():
self.llm_finish_task = True
future = asyncio.run_coroutine_threadsafe(self._check_and_broadcast_auth_code(), self.loop)
future.result()
await self._check_and_broadcast_auth_code()
return True

self.dialogue.put(Message(role="user", content=query))
Expand All @@ -246,8 +245,7 @@ def chat(self, query):
try:
start_time = time.time()
# 使用带记忆的对话
future = asyncio.run_coroutine_threadsafe(self.memory.query_memory(query), self.loop)
memory_str = future.result()
memory_str = await self.memory.query_memory(query)

self.logger.bind(tag=TAG).debug(f"记忆内容: {memory_str}")
llm_responses = self.llm.response(
Expand Down Expand Up @@ -290,8 +288,7 @@ def chat(self, query):
# segment_text = " "
text_index += 1
self.recode_first_last_text(segment_text, text_index)
future = self.executor.submit(self.speak_and_play, segment_text, text_index)
self.tts_queue.put(future)
await self.tts_queue.put((segment_text, text_index))
processed_chars += len(segment_text_raw) # 更新已处理字符位置

# 处理最后剩余的文本
Expand All @@ -302,21 +299,24 @@ def chat(self, query):
if segment_text:
text_index += 1
self.recode_first_last_text(segment_text, text_index)
future = self.executor.submit(self.speak_and_play, segment_text, text_index)
self.tts_queue.put(future)
await self.tts_queue.put((segment_text, text_index))

self.llm_finish_task = True
self.dialogue.put(Message(role="assistant", content="".join(response_message)))
self.logger.bind(tag=TAG).debug(json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False))
return True

def chat_with_function_calling(self, query, tool_call=False):
def create_chat_task(self, query):
task = asyncio.create_task(self.chat(query))
self.background_tasks.add(task)
task.add_done_callback(self.background_tasks.discard)

async def chat_with_function_calling(self, query, tool_call=False):
self.logger.bind(tag=TAG).debug(f"Chat with function calling start: {query}")
"""Chat with function calling for intent detection using streaming"""
if self.isNeedAuth():
self.llm_finish_task = True
future = asyncio.run_coroutine_threadsafe(self._check_and_broadcast_auth_code(), self.loop)
future.result()
await self._check_and_broadcast_auth_code()
return True

if not tool_call:
Expand All @@ -332,8 +332,7 @@ def chat_with_function_calling(self, query, tool_call=False):
start_time = time.time()

# 使用带记忆的对话
future = asyncio.run_coroutine_threadsafe(self.memory.query_memory(query), self.loop)
memory_str = future.result()
memory_str = await self.memory.query_memory(query)

# self.logger.bind(tag=TAG).info(f"对话记录: {self.dialogue.get_llm_dialogue_with_memory(memory_str)}")

Expand Down Expand Up @@ -403,8 +402,7 @@ def chat_with_function_calling(self, query, tool_call=False):
if segment_text:
text_index += 1
self.recode_first_last_text(segment_text, text_index)
future = self.executor.submit(self.speak_and_play, segment_text, text_index)
self.tts_queue.put(future)
await self.tts_queue.put((segment_text, text_index))
processed_chars += len(segment_text_raw) # 更新已处理字符位置

# 处理function call
Expand Down Expand Up @@ -437,7 +435,7 @@ def chat_with_function_calling(self, query, tool_call=False):
"arguments": function_arguments
}
result = self.func_handler.handle_llm_function_call(self, function_call_data)
self._handle_function_result(result, function_call_data, text_index + 1)
await self._handle_function_result(result, function_call_data, text_index + 1)

# 处理最后剩余的文本
full_text = "".join(response_message)
Expand All @@ -447,8 +445,7 @@ def chat_with_function_calling(self, query, tool_call=False):
if segment_text:
text_index += 1
self.recode_first_last_text(segment_text, text_index)
future = self.executor.submit(self.speak_and_play, segment_text, text_index)
self.tts_queue.put(future)
await self.tts_queue.put((segment_text, text_index))

# 存储对话内容
if len(response_message) > 0:
Expand All @@ -459,12 +456,16 @@ def chat_with_function_calling(self, query, tool_call=False):

return True

def _handle_function_result(self, result, function_call_data, text_index):
def create_chat_with_function_calling_task(self, query, tool_call=False):
task = asyncio.create_task(self.chat_with_function_calling(query, tool_call))
self.background_tasks.add(task)
task.add_done_callback(self.background_tasks.discard)

async def _handle_function_result(self, result, function_call_data, text_index):
if result.action == Action.RESPONSE: # 直接回复前端
text = result.response
self.recode_first_last_text(text, text_index)
future = self.executor.submit(self.speak_and_play, text, text_index)
self.tts_queue.put(future)
await self.tts_queue.put((text, text_index))
self.dialogue.put(Message(role="assistant", content=text))
elif result.action == Action.REQLLM: # 调用函数后再请求llm生成回复

Expand All @@ -481,33 +482,32 @@ def _handle_function_result(self, result, function_call_data, text_index):
"index": 0}]))

self.dialogue.put(Message(role="tool", tool_call_id=function_id, content=text))
self.chat_with_function_calling(text, tool_call=True)
await self.chat_with_function_calling(text, tool_call=True)
elif result.action == Action.NOTFOUND:
text = result.result
self.recode_first_last_text(text, text_index)
future = self.executor.submit(self.speak_and_play, text, text_index)
self.tts_queue.put(future)
await self.tts_queue.put((text, text_index))
self.dialogue.put(Message(role="assistant", content=text))
else:
text = result.result
self.recode_first_last_text(text, text_index)
future = self.executor.submit(self.speak_and_play, text, text_index)
self.tts_queue.put(future)
await self.tts_queue.put((text, text_index))
self.dialogue.put(Message(role="assistant", content=text))

def _tts_priority_thread(self):
async def _tts_priority_task(self):
while not self.stop_event.is_set():
text = None
try:
future = self.tts_queue.get()
if future is None:
text, text_index = await self.tts_queue.get()
if text is None:
continue
text = None
opus_datas, text_index, tts_file = [], 0, None

opus_datas, tts_file = [], None
try:
self.logger.bind(tag=TAG).debug("正在处理TTS任务...")
tts_timeout = self.config.get("tts_timeout", 10)
tts_file, text, text_index = future.result(timeout=tts_timeout)
tts_file, text, text_index = await asyncio.wait_for(self.speak_and_play(text, text_index), timeout=tts_timeout)

if text is None or len(text) <= 0:
self.logger.bind(tag=TAG).error(f"TTS出错:{text_index}: tts text is empty")
elif tts_file is None:
Expand All @@ -518,40 +518,40 @@ def _tts_priority_thread(self):
opus_datas, duration = self.tts.audio_to_opus_data(tts_file)
else:
self.logger.bind(tag=TAG).error(f"TTS出错:文件不存在{tts_file}")
except TimeoutError:
except asyncio.TimeoutError:
self.logger.bind(tag=TAG).error("TTS超时")
except Exception as e:
self.logger.bind(tag=TAG).error(f"TTS出错: {e}")

if not self.client_abort:
# 如果没有中途打断就发送语音
self.audio_play_queue.put((opus_datas, text, text_index))
await self.audio_play_queue.put((opus_datas, text, text_index))

if self.tts.delete_audio_file and tts_file is not None and os.path.exists(tts_file):
os.remove(tts_file)
except Exception as e:
self.logger.bind(tag=TAG).error(f"TTS任务处理错误: {e}")
self.clearSpeakStatus()
asyncio.run_coroutine_threadsafe(
self.websocket.send(json.dumps({"type": "tts", "state": "stop", "session_id": self.session_id})),
self.loop
)
self.logger.bind(tag=TAG).error(f"tts_priority priority_thread: {text} {e}")
await self.websocket.send(json.dumps({"type": "tts", "state": "stop", "session_id": self.session_id}))
self.logger.bind(tag=TAG).error(f"tts_priority task: {text} {e}")

def _audio_play_priority_thread(self):
async def _audio_play_priority_task(self):
while not self.stop_event.is_set():
text = None
try:
opus_datas, text, text_index = self.audio_play_queue.get()
future = asyncio.run_coroutine_threadsafe(sendAudioMessage(self, opus_datas, text, text_index),
self.loop)
future.result()
opus_datas, text, text_index = await self.audio_play_queue.get()
await sendAudioMessage(self, opus_datas, text, text_index)
except Exception as e:
self.logger.bind(tag=TAG).error(f"audio_play_priority priority_thread: {text} {e}")
self.logger.bind(tag=TAG).error(f"audio_play_priority task: {text} {e}")

def speak_and_play(self, text, text_index=0):
async def speak_and_play(self, text, text_index=0):
if text is None or len(text) <= 0:
self.logger.bind(tag=TAG).info(f"无需tts转换,query为空,{text}")
return None, text, text_index
tts_file = self.tts.to_tts(text)

# 使用事件循环运行同步的TTS方法
tts_file = await self.loop.run_in_executor(None, self.tts.to_tts, text)

if tts_file is None:
self.logger.bind(tag=TAG).error(f"tts转换失败,{text}")
return None, text, text_index
Expand All @@ -575,7 +575,6 @@ async def close(self):

# 清理其他资源
self.stop_event.set()
self.executor.shutdown(wait=False)
if self.websocket:
await self.websocket.close()
self.logger.bind(tag=TAG).info("连接资源已释放")
Expand All @@ -587,13 +586,18 @@ def reset_vad_states(self):
self.client_voice_stop = False
self.logger.bind(tag=TAG).debug("VAD states reset.")

def chat_and_close(self, text):
async def chat_and_close(self, text):
"""Chat with the user and then close the connection"""
try:
# Use the existing chat method
self.chat(text)
await self.chat(text)

# After chat is complete, close the connection
self.close_after_chat = True
except Exception as e:
self.logger.bind(tag=TAG).error(f"Chat and close error: {str(e)}")

def create_chat_and_close_task(self, text):
task = asyncio.create_task(self.chat_and_close(text))
self.background_tasks.add(task)
task.add_done_callback(self.background_tasks.discard)
4 changes: 2 additions & 2 deletions main/xiaozhi-server/core/handle/intentHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def process_intent_result(conn, intent, original_text):
logger.bind(tag=TAG).info(f"识别到退出意图: {intent}")
# 如果是明确的离别意图,发送告别语并关闭连接
await send_stt_message(conn, original_text)
conn.executor.submit(conn.chat_and_close, original_text)
conn.create_chat_and_close_task(original_text)
return True

# 处理播放音乐意图
Expand Down Expand Up @@ -112,4 +112,4 @@ def extract_text_in_brackets(s):
if left_bracket_index != -1 and right_bracket_index != -1 and left_bracket_index < right_bracket_index:
return s[left_bracket_index + 1:right_bracket_index]
else:
return ""
return ""
4 changes: 2 additions & 2 deletions main/xiaozhi-server/core/handle/receiveAudioHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ async def startToChat(conn, text):
await send_stt_message(conn, text)
if conn.use_function_call_mode:
# 使用支持function calling的聊天方法
conn.executor.submit(conn.chat_with_function_calling, text)
conn.create_chat_with_function_calling_task(text)
else:
conn.executor.submit(conn.chat, text)
conn.create_chat_task(text)


async def no_voice_close_connect(conn):
Expand Down