diff --git a/bot/dashscope/dashscope_bot.py b/bot/dashscope/dashscope_bot.py index 07554c4d2..05939654e 100644 --- a/bot/dashscope/dashscope_bot.py +++ b/bot/dashscope/dashscope_bot.py @@ -9,6 +9,7 @@ from .dashscope_session import DashscopeSession import os import dashscope +from dashscope import Application from http import HTTPStatus @@ -19,13 +20,14 @@ "qwen-max": dashscope.Generation.Models.qwen_max, "qwen-bailian-v1": dashscope.Generation.Models.bailian_v1 } -# ZhipuAI对话模型API +# 百炼AI对话模型API class DashscopeBot(Bot): def __init__(self): super().__init__() self.sessions = SessionManager(DashscopeSession, model=conf().get("model") or "qwen-plus") self.model_name = conf().get("model") or "qwen-plus" self.api_key = conf().get("dashscope_api_key") + self.app_id = conf().get("dashscope_app_id") os.environ["DASHSCOPE_API_KEY"] = self.api_key self.client = dashscope.Generation @@ -48,10 +50,11 @@ def reply(self, query, context=None): reply = Reply(ReplyType.INFO, "配置已更新") if reply: return reply + session = self.sessions.session_query(query, session_id) logger.debug("[DASHSCOPE] session query={}".format(session.messages)) - reply_content = self.reply_text(session) + reply_content = self.reply_text(session, context=context, query=query) logger.debug( "[DASHSCOPE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( session.messages, @@ -73,15 +76,41 @@ def reply(self, query, context=None): reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) return reply - def reply_text(self, session: DashscopeSession, retry_count=0) -> dict: + def reply_text(self, session: DashscopeSession, retry_count=0, context=None, query=None) -> dict: """ - call openai's ChatCompletion to get the answer + 支持模型对话和应用调用的统一回复方法 :param session: a conversation session - :param session_id: session id :param retry_count: retry count + :param context: 上下文参数 + :param query: 用户原始问题 :return: {} """ try: + dashscope_app_id = None + if self.app_id is not None: + dashscope_app_id = self.app_id + if dashscope_app_id: + # 应用调用模式 + dashscope.api_key = self.api_key + try: + response = Application.call( + api_key=self.api_key, + app_id=dashscope_app_id, + prompt=query or (session.messages[-1]["content"] if session.messages else "") + ) + if response.status_code == HTTPStatus.OK: + return { + "total_tokens": 0, + "completion_tokens": 1, + "content": response.output.text, + } + else: + logger.error(f"[DASHSCOPE_APP] request_id={response.request_id}, code={response.status_code}, message={response.message}") + return {"completion_tokens": 0, "content": response.message, "total_tokens": 0} + except Exception as e: + logger.exception(e) + return {"completion_tokens": 0, "content": "应用调用异常,请稍后重试", "total_tokens": 0} + # 默认走原有模型对话 dashscope.api_key = self.api_key response = self.client.call( dashscope_models[self.model_name], @@ -100,18 +129,17 @@ def reply_text(self, session: DashscopeSession, retry_count=0) -> dict: response.request_id, response.status_code, response.code, response.message )) - result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} + result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧", "total_tokens": 0} need_retry = retry_count < 2 - result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} if need_retry: - return self.reply_text(session, retry_count + 1) + return self.reply_text(session, retry_count + 1, context=context, query=query) else: return result except Exception as e: logger.exception(e) need_retry = retry_count < 2 - result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} + result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧", "total_tokens": 0} if need_retry: - return self.reply_text(session, retry_count + 1) + return self.reply_text(session, retry_count + 1, context=context, query=query) else: return result diff --git a/config.py b/config.py index a02bfae5e..ed32a6b66 100644 --- a/config.py +++ b/config.py @@ -90,6 +90,7 @@ "qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串 # 阿里灵积(通义新版sdk)模型api key "dashscope_api_key": "", + "dashscope_app_id": "", # Google Gemini Api Key "gemini_api_key": "", # wework的通用配置