mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
py : fix oai proxy (#3972)
* fix oai proxy fix generation not stoped while bot stop talking in chat mode fix possible `slot_id` not exist response for cors (and pre flight) * oai proxy: workaround for some client (such as Chatbox) * use stop as separator to replace hardcoded `\n`
This commit is contained in:
parent
1f5cd83275
commit
e2bd725f4b
@ -11,10 +11,10 @@ app = Flask(__name__)
|
|||||||
slot_id = -1
|
slot_id = -1
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
|
parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
|
||||||
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
|
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')
|
||||||
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
|
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ")
|
||||||
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
|
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ")
|
||||||
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
|
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ")
|
||||||
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
|
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
|
||||||
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
|
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
|
||||||
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
|
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
|
||||||
@ -34,19 +34,19 @@ def is_present(json, key):
|
|||||||
|
|
||||||
#convert chat to prompt
|
#convert chat to prompt
|
||||||
def convert_chat(messages):
|
def convert_chat(messages):
|
||||||
prompt = "" + args.chat_prompt.replace("\\n", "\n")
|
|
||||||
|
|
||||||
system_n = args.system_name.replace("\\n", "\n")
|
system_n = args.system_name
|
||||||
user_n = args.user_name.replace("\\n", "\n")
|
user_n = args.user_name
|
||||||
ai_n = args.ai_name.replace("\\n", "\n")
|
ai_n = args.ai_name
|
||||||
stop = args.stop.replace("\\n", "\n")
|
stop = args.stop
|
||||||
|
|
||||||
|
prompt = "" + args.chat_prompt + stop
|
||||||
|
|
||||||
for line in messages:
|
for line in messages:
|
||||||
if (line["role"] == "system"):
|
if (line["role"] == "system"):
|
||||||
prompt += f"{system_n}{line['content']}"
|
prompt += f"{system_n}{line['content']}{stop}"
|
||||||
if (line["role"] == "user"):
|
if (line["role"] == "user"):
|
||||||
prompt += f"{user_n}{line['content']}"
|
prompt += f"{user_n}{line['content']}{stop}"
|
||||||
if (line["role"] == "assistant"):
|
if (line["role"] == "assistant"):
|
||||||
prompt += f"{ai_n}{line['content']}{stop}"
|
prompt += f"{ai_n}{line['content']}{stop}"
|
||||||
prompt += ai_n.rstrip()
|
prompt += ai_n.rstrip()
|
||||||
@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
slot_id = data["slot_id"]
|
slot_id = data.get("slot_id")
|
||||||
if (chat):
|
if (chat):
|
||||||
if (start):
|
if (start):
|
||||||
resData["choices"][0]["delta"] = {
|
resData["choices"][0]["delta"] = {
|
||||||
@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
|
|||||||
return resData
|
return resData
|
||||||
|
|
||||||
|
|
||||||
@app.route('/chat/completions', methods=['POST'])
|
@app.route('/chat/completions', methods=['POST', 'OPTIONS'])
|
||||||
@app.route('/v1/chat/completions', methods=['POST'])
|
@app.route('/v1/chat/completions', methods=['POST', 'OPTIONS'])
|
||||||
def chat_completions():
|
def chat_completions():
|
||||||
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
|
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
|
||||||
return Response(status=403)
|
return Response(status=403)
|
||||||
|
if request.method == 'OPTIONS':
|
||||||
|
return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
|
||||||
body = request.get_json()
|
body = request.get_json()
|
||||||
stream = False
|
stream = False
|
||||||
tokenize = False
|
tokenize = False
|
||||||
@ -177,20 +179,22 @@ def chat_completions():
|
|||||||
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
|
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
|
||||||
time_now = int(time.time())
|
time_now = int(time.time())
|
||||||
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
|
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
|
||||||
yield 'data: {}\n'.format(json.dumps(resData))
|
yield 'data: {}\n\n'.format(json.dumps(resData))
|
||||||
for line in data.iter_lines():
|
for line in data.iter_lines():
|
||||||
if line:
|
if line:
|
||||||
decoded_line = line.decode('utf-8')
|
decoded_line = line.decode('utf-8')
|
||||||
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
|
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
|
||||||
yield 'data: {}\n'.format(json.dumps(resData))
|
yield 'data: {}\n\n'.format(json.dumps(resData))
|
||||||
return Response(generate(), mimetype='text/event-stream')
|
return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
|
||||||
|
|
||||||
|
|
||||||
@app.route('/completions', methods=['POST'])
|
@app.route('/completions', methods=['POST', 'OPTIONS'])
|
||||||
@app.route('/v1/completions', methods=['POST'])
|
@app.route('/v1/completions', methods=['POST', 'OPTIONS'])
|
||||||
def completion():
|
def completion():
|
||||||
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
|
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
|
||||||
return Response(status=403)
|
return Response(status=403)
|
||||||
|
if request.method == 'OPTIONS':
|
||||||
|
return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
|
||||||
body = request.get_json()
|
body = request.get_json()
|
||||||
stream = False
|
stream = False
|
||||||
tokenize = False
|
tokenize = False
|
||||||
@ -216,8 +220,8 @@ def completion():
|
|||||||
if line:
|
if line:
|
||||||
decoded_line = line.decode('utf-8')
|
decoded_line = line.decode('utf-8')
|
||||||
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
|
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
|
||||||
yield 'data: {}\n'.format(json.dumps(resData))
|
yield 'data: {}\n\n'.format(json.dumps(resData))
|
||||||
return Response(generate(), mimetype='text/event-stream')
|
return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(args.host, port=args.port)
|
app.run(args.host, port=args.port)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user