Skip to content
30 changes: 22 additions & 8 deletions request_llms/bridge_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
try:
# make a POST request to the API endpoint, stream=False
from .bridge_all import model_info
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint']) if not llm_kwargs['tmp_endpoint'] else llm_kwargs['tmp_endpoint']
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
except requests.exceptions.ReadTimeout as e:
Expand Down Expand Up @@ -147,7 +147,12 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
return result


def is_any_tmp_model(inputs):
    """Return True if *inputs* is a JSON object describing a temporary model.

    A temporary-model spec is a JSON dict that contains all three keys
    "tmp_key", "tmp_model" and "tmp_endpoint". Anything else — invalid
    JSON, a non-string, or a JSON value that is not a dict — yields False.
    """
    try:
        parsed = json.loads(inputs)
    except (TypeError, ValueError):
        # not a string/bytes, or not valid JSON — definitely not a tmp-model spec
        return False
    # json.loads may return a list/str/number; only a dict can carry the keys
    if not isinstance(parsed, dict):
        return False
    return {"tmp_key", "tmp_model", "tmp_endpoint"} <= parsed.keys()
def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWithCookies,
history:list=[], system_prompt:str='', stream:bool=True, additional_fn:str=None):
"""
Expand All @@ -164,8 +169,15 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
yield from update_ui(chatbot=chatbot, history=history, msg="api_key已导入") # 刷新界面
return
elif not is_any_api_key(chatbot._cookies['api_key']):
chatbot.append((inputs, "缺少api_key。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。"))
elif is_any_tmp_model(inputs):
chatbot._cookies['tmp_key'] = json.loads(inputs)['tmp_key']
chatbot._cookies['tmp_model'] = json.loads(inputs)['tmp_model']
chatbot._cookies['tmp_endpoint'] = json.loads(inputs)['tmp_endpoint']
chatbot.append(("输入已识别为临时openai格式的模型,页面刷新后将失效", '临时模型:'+json.loads(inputs)['tmp_model']))
yield from update_ui(chatbot=chatbot, history=history, msg="临时模型已导入") # 刷新界面
return
elif not is_any_api_key(chatbot._cookies['api_key']) and not chatbot._cookies['tmp_key']:
chatbot.append((inputs, '缺少API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。\n\n3.接入临时模型:在输入区键入以下格式临时模型信息{"tmp_key":"xxx","tmp_endpoint":"https://xxxx.xxx","tmp_model":"gpt-3.5-turbo-16k"},然后回车提交'))
yield from update_ui(chatbot=chatbot, history=history, msg="缺少api_key") # 刷新界面
return

Expand Down Expand Up @@ -195,7 +207,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
# 检查endpoint是否合法
try:
from .bridge_all import model_info
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint']) if not llm_kwargs['tmp_endpoint'] else llm_kwargs['tmp_endpoint']
except:
tb_str = '```\n' + trimmed_format_exc() + '```'
chatbot[-1] = (inputs, tb_str)
Expand Down Expand Up @@ -320,11 +332,13 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
"""
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
"""
if not is_any_api_key(llm_kwargs['api_key']):
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
if not is_any_api_key(llm_kwargs['api_key']) and not llm_kwargs['tmp_key']:
raise AssertionError('你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。\n\n3.接入临时模型:在输入区键入以下格式临时模型信息{"tmp_key":"xxx","tmp_endpoint":"https://xxxx.xxx","tmp_model":"gpt-3.5-turbo-16k"},然后回车提交')

if llm_kwargs['llm_model'].startswith('vllm-'):
api_key = 'no-api-key'
elif llm_kwargs['tmp_key']:
api_key = llm_kwargs['tmp_key']
else:
api_key = select_api_key(llm_kwargs['api_key'], llm_kwargs['llm_model'])

Expand Down Expand Up @@ -383,7 +397,7 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
logging.info("Random select model:" + model)

payload = {
"model": model,
"model": model if not llm_kwargs['tmp_model'] else llm_kwargs['tmp_model'] ,
"messages": messages,
"temperature": llm_kwargs['temperature'], # 1.0,
"top_p": llm_kwargs['top_p'], # 1.0,
Expand Down
9 changes: 9 additions & 0 deletions toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,18 @@ def decorated(request: gradio.Request, cookies:dict, max_length:int, llm_model:s
cookies.update({
'top_p': top_p,
'api_key': cookies['api_key'],
'tmp_key': cookies['tmp_key'],
'tmp_model': cookies['tmp_model'],
'tmp_endpoint': cookies['tmp_endpoint'],
'llm_model': llm_model,
'temperature': temperature,
'user_name': user_name,
})
llm_kwargs = {
'api_key': cookies['api_key'],
'tmp_key': cookies['tmp_key'],
'tmp_model': cookies['tmp_model'],
'tmp_endpoint': cookies['tmp_endpoint'],
'llm_model': llm_model,
'top_p': top_p,
'max_length': max_length,
Expand Down Expand Up @@ -607,6 +613,9 @@ def load_chat_cookies():
"api_key": API_KEY,
"llm_model": LLM_MODEL,
"customize_fn_overwrite": customize_fn_overwrite_,
"tmp_key":'',
"tmp_model":'',
"tmp_endpoint":'',
}


Expand Down