diff --git a/request_llms/bridge_chatgpt.py b/request_llms/bridge_chatgpt.py index 2e4ebe3e4..858344954 100644 --- a/request_llms/bridge_chatgpt.py +++ b/request_llms/bridge_chatgpt.py @@ -91,7 +91,7 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[], try: # make a POST request to the API endpoint, stream=False from .bridge_all import model_info - endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint']) + endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint']) if not llm_kwargs['tmp_endpoint'] else llm_kwargs['tmp_endpoint'] response = requests.post(endpoint, headers=headers, proxies=proxies, json=payload, stream=True, timeout=TIMEOUT_SECONDS); break except requests.exceptions.ReadTimeout as e: @@ -147,7 +147,12 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[], raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。") return result - +def is_any_tmp_model(inputs): + try: + tmp_model_info=json.loads(inputs).keys() + return "tmp_key" in tmp_model_info and "tmp_model" in tmp_model_info and "tmp_endpoint" in tmp_model_info + except: + return False def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWithCookies, history:list=[], system_prompt:str='', stream:bool=True, additional_fn:str=None): """ @@ -164,8 +169,15 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith chatbot.append(("输入已识别为openai的api_key", what_keys(inputs))) yield from update_ui(chatbot=chatbot, history=history, msg="api_key已导入") # 刷新界面 return - elif not is_any_api_key(chatbot._cookies['api_key']): - chatbot.append((inputs, "缺少api_key。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")) + elif is_any_tmp_model(inputs): + chatbot._cookies['tmp_key'] = json.loads(inputs)['tmp_key'] + chatbot._cookies['tmp_model'] = json.loads(inputs)['tmp_model'] + chatbot._cookies['tmp_endpoint'] = json.loads(inputs)['tmp_endpoint'] + chatbot.append(("输入已识别为临时openai格式的模型,页面刷新后将失效", '临时模型:'+json.loads(inputs)['tmp_model'])) + yield from update_ui(chatbot=chatbot, history=history, msg="临时模型已导入") # 刷新界面 + return + elif not is_any_api_key(chatbot._cookies['api_key']) and not chatbot._cookies['tmp_key']: + chatbot.append((inputs, '缺少API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。\n\n3.接入临时模型:在输入区键入以下格式临时模型信息{"tmp_key":"xxx","tmp_endpoint":"https://xxxx.xxx","tmp_model":"gpt-3.5-turbo-16k"},然后回车提交')) yield from update_ui(chatbot=chatbot, history=history, msg="缺少api_key") # 刷新界面 return @@ -195,7 +207,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith # 检查endpoint是否合法 try: from .bridge_all import model_info - endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint']) + endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint']) if not llm_kwargs['tmp_endpoint'] else llm_kwargs['tmp_endpoint'] except: tb_str = '```\n' + trimmed_format_exc() + '```' chatbot[-1] = (inputs, tb_str) @@ -320,11 +332,13 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream): """ 整合所有信息,选择LLM模型,生成http请求,为发送请求做准备 """ - if not is_any_api_key(llm_kwargs['api_key']): - raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。") + if not is_any_api_key(llm_kwargs['api_key']) and not llm_kwargs['tmp_key']: + raise AssertionError('你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。\n\n3.接入临时模型:在输入区键入以下格式临时模型信息{"tmp_key":"xxx","tmp_endpoint":"https://xxxx.xxx","tmp_model":"gpt-3.5-turbo-16k"},然后回车提交') if llm_kwargs['llm_model'].startswith('vllm-'): api_key = 'no-api-key' + elif llm_kwargs['tmp_key']: + api_key = llm_kwargs['tmp_key'] else: api_key = select_api_key(llm_kwargs['api_key'], llm_kwargs['llm_model']) @@ -383,7 +397,7 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream): logging.info("Random select model:" + model) payload = { - "model": model, + "model": model if not llm_kwargs['tmp_model'] else llm_kwargs['tmp_model'] , "messages": messages, "temperature": llm_kwargs['temperature'], # 1.0, "top_p": llm_kwargs['top_p'], # 1.0, diff --git a/toolbox.py b/toolbox.py index e11efc126..be921d4e6 100644 --- a/toolbox.py +++ b/toolbox.py @@ -101,12 +101,18 @@ def decorated(request: gradio.Request, cookies:dict, max_length:int, llm_model:s cookies.update({ 'top_p': top_p, 'api_key': cookies['api_key'], + 'tmp_key': cookies['tmp_key'], + 'tmp_model': cookies['tmp_model'], + 'tmp_endpoint': cookies['tmp_endpoint'], 'llm_model': llm_model, 'temperature': temperature, 'user_name': user_name, }) llm_kwargs = { 'api_key': cookies['api_key'], + 'tmp_key': cookies['tmp_key'], + 'tmp_model': cookies['tmp_model'], + 'tmp_endpoint': cookies['tmp_endpoint'], 'llm_model': llm_model, 'top_p': top_p, 'max_length': max_length, @@ -607,6 +613,9 @@ def load_chat_cookies(): "api_key": API_KEY, "llm_model": LLM_MODEL, "customize_fn_overwrite": customize_fn_overwrite_, + "tmp_key":'', + "tmp_model":'', + "tmp_endpoint":'', }