mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Style/pep8 improvements
This commit is contained in:
parent
ecd79caa68
commit
320fcfde4e
@ -1,8 +1,9 @@
|
|||||||
import json, time, os
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.text_generation import encode, generate_reply
|
from modules.text_generation import encode, generate_reply
|
||||||
|
|
||||||
@ -25,13 +26,15 @@ embedding_model = None
|
|||||||
standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
|
standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
|
||||||
|
|
||||||
# little helper to get defaults if arg is present but None and should be the same type as default.
|
# little helper to get defaults if arg is present but None and should be the same type as default.
|
||||||
|
|
||||||
|
|
||||||
def default(dic, key, default):
|
def default(dic, key, default):
|
||||||
val = dic.get(key, default)
|
val = dic.get(key, default)
|
||||||
if type(val) != type(default):
|
if type(val) != type(default):
|
||||||
# maybe it's just something like 1 instead of 1.0
|
# maybe it's just something like 1 instead of 1.0
|
||||||
try:
|
try:
|
||||||
v = type(default)(val)
|
v = type(default)(val)
|
||||||
if type(val)(v) == val: # if it's the same value passed in, it's ok.
|
if type(val)(v) == val: # if it's the same value passed in, it's ok.
|
||||||
return v
|
return v
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@ -39,6 +42,7 @@ def default(dic, key, default):
|
|||||||
val = default
|
val = default
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
def clamp(value, minvalue, maxvalue):
|
def clamp(value, minvalue, maxvalue):
|
||||||
return max(minvalue, min(value, maxvalue))
|
return max(minvalue, min(value, maxvalue))
|
||||||
|
|
||||||
@ -54,27 +58,27 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
# TODO: list all models and allow model changes via API? Lora's?
|
# TODO: list all models and allow model changes via API? Lora's?
|
||||||
# This API should list capabilities, limits and pricing...
|
# This API should list capabilities, limits and pricing...
|
||||||
models = [{
|
models = [{
|
||||||
"id": shared.model_name, # The real chat/completions model
|
"id": shared.model_name, # The real chat/completions model
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"owned_by": "user",
|
"owned_by": "user",
|
||||||
"permission": []
|
"permission": []
|
||||||
}, {
|
}, {
|
||||||
"id": st_model, # The real sentence transformer embeddings model
|
"id": st_model, # The real sentence transformer embeddings model
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"owned_by": "user",
|
"owned_by": "user",
|
||||||
"permission": []
|
"permission": []
|
||||||
}, { # these are expected by so much, so include some here as a dummy
|
}, { # these are expected by so much, so include some here as a dummy
|
||||||
"id": "gpt-3.5-turbo", # /v1/chat/completions
|
"id": "gpt-3.5-turbo", # /v1/chat/completions
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"owned_by": "user",
|
"owned_by": "user",
|
||||||
"permission": []
|
"permission": []
|
||||||
}, {
|
}, {
|
||||||
"id": "text-curie-001", # /v1/completions, 2k context
|
"id": "text-curie-001", # /v1/completions, 2k context
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"owned_by": "user",
|
"owned_by": "user",
|
||||||
"permission": []
|
"permission": []
|
||||||
}, {
|
}, {
|
||||||
"id": "text-davinci-002", # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
|
"id": "text-davinci-002", # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"owned_by": "user",
|
"owned_by": "user",
|
||||||
"permission": []
|
"permission": []
|
||||||
@ -103,8 +107,10 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['Content-Length'])
|
||||||
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
||||||
|
|
||||||
if debug: print(self.headers) # did you know... python-openai sends your linux kernel & python version?
|
if debug:
|
||||||
if debug: print(body)
|
print(self.headers) # did you know... python-openai sends your linux kernel & python version?
|
||||||
|
if debug:
|
||||||
|
print(body)
|
||||||
|
|
||||||
if '/completions' in self.path or '/generate' in self.path:
|
if '/completions' in self.path or '/generate' in self.path:
|
||||||
is_legacy = '/generate' in self.path
|
is_legacy = '/generate' in self.path
|
||||||
@ -112,7 +118,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
resp_list = 'data' if is_legacy else 'choices'
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
# XXX model is ignored for now
|
# XXX model is ignored for now
|
||||||
#model = body.get('model', shared.model_name) # ignored, use existing for now
|
# model = body.get('model', shared.model_name) # ignored, use existing for now
|
||||||
model = shared.model_name
|
model = shared.model_name
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
cmpl_id = "conv-%d" % (created_time)
|
cmpl_id = "conv-%d" % (created_time)
|
||||||
@ -129,11 +135,11 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
||||||
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
||||||
|
|
||||||
default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it., the default for chat is "inf"
|
default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it., the default for chat is "inf"
|
||||||
|
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
||||||
|
|
||||||
# hard scale this, assuming the given max is for GPT3/4, perhaps inspect the requested model and lookup the context max
|
# hard scale this, assuming the given max is for GPT3/4, perhaps inspect the requested model and lookup the context max
|
||||||
while truncation_length <= max_tokens:
|
while truncation_length <= max_tokens:
|
||||||
max_tokens = max_tokens // 2
|
max_tokens = max_tokens // 2
|
||||||
@ -143,17 +149,17 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
'temperature': default(body, 'temperature', 1.0),
|
'temperature': default(body, 'temperature', 1.0),
|
||||||
'top_p': default(body, 'top_p', 1.0),
|
'top_p': default(body, 'top_p', 1.0),
|
||||||
'top_k': default(body, 'best_of', 1),
|
'top_k': default(body, 'best_of', 1),
|
||||||
### XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
|
# XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
|
||||||
# 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
|
# 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
|
||||||
'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
|
'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
|
||||||
### XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
|
# XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
|
||||||
'encoder_repetition_penalty': 1.0, #(default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
|
'encoder_repetition_penalty': 1.0, # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
|
||||||
'suffix': body.get('suffix', None),
|
'suffix': body.get('suffix', None),
|
||||||
'stream': default(body, 'stream', False),
|
'stream': default(body, 'stream', False),
|
||||||
'echo': default(body, 'echo', False),
|
'echo': default(body, 'echo', False),
|
||||||
#####################################################
|
#####################################################
|
||||||
'seed': shared.settings.get('seed', -1),
|
'seed': shared.settings.get('seed', -1),
|
||||||
#int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
|
# int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
|
||||||
# unofficial, but it needs to get set anyways.
|
# unofficial, but it needs to get set anyways.
|
||||||
'truncation_length': truncation_length,
|
'truncation_length': truncation_length,
|
||||||
# no more args.
|
# no more args.
|
||||||
@ -178,7 +184,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if req_params['stream']:
|
if req_params['stream']:
|
||||||
self.send_header('Content-Type', 'text/event-stream')
|
self.send_header('Content-Type', 'text/event-stream')
|
||||||
self.send_header('Cache-Control', 'no-cache')
|
self.send_header('Cache-Control', 'no-cache')
|
||||||
#self.send_header('Connection', 'keep-alive')
|
# self.send_header('Connection', 'keep-alive')
|
||||||
else:
|
else:
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
@ -195,8 +201,8 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
messages = body['messages']
|
messages = body['messages']
|
||||||
|
|
||||||
system_msg = '' # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
system_msg = '' # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
||||||
if 'prompt' in body: # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
if 'prompt' in body: # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
||||||
system_msg = body['prompt']
|
system_msg = body['prompt']
|
||||||
|
|
||||||
chat_msgs = []
|
chat_msgs = []
|
||||||
@ -204,16 +210,16 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
for m in messages:
|
for m in messages:
|
||||||
role = m['role']
|
role = m['role']
|
||||||
content = m['content']
|
content = m['content']
|
||||||
#name = m.get('name', 'user')
|
# name = m.get('name', 'user')
|
||||||
if role == 'system':
|
if role == 'system':
|
||||||
system_msg += content
|
system_msg += content
|
||||||
else:
|
else:
|
||||||
chat_msgs.extend([f"\n{role}: {content.strip()}"]) ### Strip content? linefeed?
|
chat_msgs.extend([f"\n{role}: {content.strip()}"]) # Strip content? linefeed?
|
||||||
|
|
||||||
system_token_count = len(encode(system_msg)[0])
|
system_token_count = len(encode(system_msg)[0])
|
||||||
remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
|
remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
|
||||||
chat_msg = ''
|
chat_msg = ''
|
||||||
|
|
||||||
while chat_msgs:
|
while chat_msgs:
|
||||||
new_msg = chat_msgs.pop()
|
new_msg = chat_msgs.pop()
|
||||||
new_size = len(encode(new_msg)[0])
|
new_size = len(encode(new_msg)[0])
|
||||||
@ -229,7 +235,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
print(f"truncating chat messages, dropping {len(chat_msgs)} messages.")
|
print(f"truncating chat messages, dropping {len(chat_msgs)} messages.")
|
||||||
|
|
||||||
if system_msg:
|
if system_msg:
|
||||||
prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant: '
|
prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant: '
|
||||||
else:
|
else:
|
||||||
prompt = chat_msg + '\nassistant: '
|
prompt = chat_msg + '\nassistant: '
|
||||||
|
|
||||||
@ -245,16 +251,16 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||||
if is_legacy:
|
if is_legacy:
|
||||||
prompt = body['context'] # Older engines.generate API
|
prompt = body['context'] # Older engines.generate API
|
||||||
else:
|
else:
|
||||||
prompt = body['prompt'] # XXX this can be different types
|
prompt = body['prompt'] # XXX this can be different types
|
||||||
|
|
||||||
if isinstance(prompt, list):
|
if isinstance(prompt, list):
|
||||||
prompt = ''.join(prompt) # XXX this is wrong... need to split out to multiple calls?
|
prompt = ''.join(prompt) # XXX this is wrong... need to split out to multiple calls?
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
token_count = len(encode(prompt)[0])
|
||||||
if token_count >= req_params['truncation_length']:
|
if token_count >= req_params['truncation_length']:
|
||||||
new_len = int(len(prompt) * (float(shared.settings['truncation_length']) - req_params['max_new_tokens']) / token_count)
|
new_len = int(len(prompt) * (float(shared.settings['truncation_length']) - req_params['max_new_tokens']) / token_count)
|
||||||
prompt = prompt[-new_len:]
|
prompt = prompt[-new_len:]
|
||||||
print(f"truncating prompt to {new_len} characters, was {token_count} tokens. Now: {len(encode(prompt)[0])} tokens.")
|
print(f"truncating prompt to {new_len} characters, was {token_count} tokens. Now: {len(encode(prompt)[0])} tokens.")
|
||||||
|
|
||||||
@ -262,7 +268,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
# some strange cases of "##| Instruction: " sneaking through.
|
# some strange cases of "##| Instruction: " sneaking through.
|
||||||
stopping_strings += standard_stopping_strings
|
stopping_strings += standard_stopping_strings
|
||||||
req_params['custom_stopping_strings'] = stopping_strings
|
req_params['custom_stopping_strings'] = stopping_strings
|
||||||
|
|
||||||
|
|
||||||
shared.args.no_stream = not req_params['stream']
|
shared.args.no_stream = not req_params['stream']
|
||||||
if not shared.args.no_stream:
|
if not shared.args.no_stream:
|
||||||
@ -283,22 +288,23 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
chunk[resp_list][0]["text"] = ""
|
chunk[resp_list][0]["text"] = ""
|
||||||
else:
|
else:
|
||||||
# This is coming back as "system" to the openapi cli, not sure why.
|
# This is coming back as "system" to the openapi cli, not sure why.
|
||||||
# So yeah... do both methods? delta and messages.
|
# So yeah... do both methods? delta and messages.
|
||||||
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
|
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
|
||||||
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
|
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
|
||||||
#{ "role": "assistant" }
|
# { "role": "assistant" }
|
||||||
|
|
||||||
response = 'data: ' + json.dumps(chunk) + '\n'
|
response = 'data: ' + json.dumps(chunk) + '\n'
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
# generate reply #######################################
|
# generate reply #######################################
|
||||||
if debug: print ({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
|
if debug:
|
||||||
|
print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
|
||||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings)
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
seen_content = ''
|
seen_content = ''
|
||||||
longest_stop_len = max([ len(x) for x in stopping_strings ])
|
longest_stop_len = max([len(x) for x in stopping_strings])
|
||||||
|
|
||||||
for a in generator:
|
for a in generator:
|
||||||
if isinstance(a, str):
|
if isinstance(a, str):
|
||||||
answer = a
|
answer = a
|
||||||
@ -312,7 +318,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
for string in stopping_strings:
|
for string in stopping_strings:
|
||||||
idx = answer.find(string, search_start)
|
idx = answer.find(string, search_start)
|
||||||
if idx != -1:
|
if idx != -1:
|
||||||
answer = answer[:idx] # clip it.
|
answer = answer[:idx] # clip it.
|
||||||
stop_string_found = True
|
stop_string_found = True
|
||||||
|
|
||||||
if stop_string_found:
|
if stop_string_found:
|
||||||
@ -338,9 +344,9 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
# Streaming
|
# Streaming
|
||||||
new_content = answer[len_seen:]
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
seen_content = answer
|
seen_content = answer
|
||||||
chunk = {
|
chunk = {
|
||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
@ -355,9 +361,9 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if stream_object_type == 'text_completion.chunk':
|
if stream_object_type == 'text_completion.chunk':
|
||||||
chunk[resp_list][0]['text'] = new_content
|
chunk[resp_list][0]['text'] = new_content
|
||||||
else:
|
else:
|
||||||
# So yeah... do both methods? delta and messages.
|
# So yeah... do both methods? delta and messages.
|
||||||
chunk[resp_list][0]['message'] = { 'content': new_content }
|
chunk[resp_list][0]['message'] = {'content': new_content}
|
||||||
chunk[resp_list][0]['delta'] = { 'content': new_content }
|
chunk[resp_list][0]['delta'] = {'content': new_content}
|
||||||
response = 'data: ' + json.dumps(chunk) + '\n'
|
response = 'data: ' + json.dumps(chunk) + '\n'
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
completion_token_count += len(encode(new_content)[0])
|
completion_token_count += len(encode(new_content)[0])
|
||||||
@ -367,7 +373,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
"object": stream_object_type,
|
"object": stream_object_type,
|
||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": model, # TODO: add Lora info?
|
"model": model, # TODO: add Lora info?
|
||||||
resp_list: [{
|
resp_list: [{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": "stop",
|
"finish_reason": "stop",
|
||||||
@ -381,16 +387,18 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if stream_object_type == 'text_completion.chunk':
|
if stream_object_type == 'text_completion.chunk':
|
||||||
chunk[resp_list][0]['text'] = ''
|
chunk[resp_list][0]['text'] = ''
|
||||||
else:
|
else:
|
||||||
# So yeah... do both methods? delta and messages.
|
# So yeah... do both methods? delta and messages.
|
||||||
chunk[resp_list][0]['message'] = {'content': '' }
|
chunk[resp_list][0]['message'] = {'content': ''}
|
||||||
chunk[resp_list][0]['delta'] = {}
|
chunk[resp_list][0]['delta'] = {}
|
||||||
response = 'data: ' + json.dumps(chunk) + '\ndata: [DONE]\n'
|
response = 'data: ' + json.dumps(chunk) + '\ndata: [DONE]\n'
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
###### Finished if streaming.
|
# Finished if streaming.
|
||||||
if debug: print({'response': answer})
|
if debug:
|
||||||
|
print({'response': answer})
|
||||||
return
|
return
|
||||||
|
|
||||||
if debug: print({'response': answer})
|
if debug:
|
||||||
|
print({'response': answer})
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
completion_token_count = len(encode(answer)[0])
|
||||||
stop_reason = "stop"
|
stop_reason = "stop"
|
||||||
@ -401,7 +409,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
"object": object_type,
|
"object": object_type,
|
||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": model, # TODO: add Lora info?
|
"model": model, # TODO: add Lora info?
|
||||||
resp_list: [{
|
resp_list: [{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": stop_reason,
|
"finish_reason": stop_reason,
|
||||||
@ -414,13 +422,13 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if is_chat:
|
if is_chat:
|
||||||
resp[resp_list][0]["message"] = {"role": "assistant", "content": answer }
|
resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
|
||||||
else:
|
else:
|
||||||
resp[resp_list][0]["text"] = answer
|
resp[resp_list][0]["text"] = answer
|
||||||
|
|
||||||
response = json.dumps(resp)
|
response = json.dumps(resp)
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
elif '/embeddings' in self.path and embedding_model != None:
|
elif '/embeddings' in self.path and embedding_model is not None:
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
@ -431,19 +439,20 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
embeddings = embedding_model.encode(input).tolist()
|
embeddings = embedding_model.encode(input).tolist()
|
||||||
|
|
||||||
data = [ {"object": "embedding", "embedding": emb, "index": n } for n, emb in enumerate(embeddings) ]
|
data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
|
||||||
|
|
||||||
response = json.dumps({
|
response = json.dumps({
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": data,
|
"data": data,
|
||||||
"model": st_model, # return the real model
|
"model": st_model, # return the real model
|
||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
"total_tokens": 0,
|
"total_tokens": 0,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
if debug: print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
if debug:
|
||||||
|
print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
elif '/moderations' in self.path:
|
elif '/moderations' in self.path:
|
||||||
# for now do nothing, just don't error.
|
# for now do nothing, just don't error.
|
||||||
@ -521,4 +530,3 @@ def run_server():
|
|||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
Thread(target=run_server, daemon=True).start()
|
Thread(target=run_server, daemon=True).start()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user