mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
[extensions/openai] Support undocumented base64 'encoding_format' param for compatibility with official OpenAI client (#1876)
This commit is contained in:
parent
d78b04f0b4
commit
791a38bad1
@ -1,4 +1,6 @@
|
|||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
@ -45,6 +47,20 @@ def clamp(value, minvalue, maxvalue):
|
|||||||
return max(minvalue, min(value, maxvalue))
|
return max(minvalue, min(value, maxvalue))
|
||||||
|
|
||||||
|
|
||||||
|
def float_list_to_base64(float_list):
|
||||||
|
# Convert the list to a float32 array that the OpenAPI client expects
|
||||||
|
float_array = np.array(float_list, dtype="float32")
|
||||||
|
|
||||||
|
# Get raw bytes
|
||||||
|
bytes_array = float_array.tobytes()
|
||||||
|
|
||||||
|
# Encode bytes into base64
|
||||||
|
encoded_bytes = base64.b64encode(bytes_array)
|
||||||
|
|
||||||
|
# Turn raw base64 encoded bytes into ASCII
|
||||||
|
ascii_string = encoded_bytes.decode('ascii')
|
||||||
|
return ascii_string
|
||||||
|
|
||||||
class Handler(BaseHTTPRequestHandler):
|
class Handler(BaseHTTPRequestHandler):
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
if self.path.startswith('/v1/models'):
|
if self.path.startswith('/v1/models'):
|
||||||
@ -435,7 +451,13 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
embeddings = embedding_model.encode(input).tolist()
|
embeddings = embedding_model.encode(input).tolist()
|
||||||
|
|
||||||
data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
|
def enc_emb(emb):
|
||||||
|
# If base64 is specified, encode. Otherwise, do nothing.
|
||||||
|
if body.get("encoding_format", "") == "base64":
|
||||||
|
return float_list_to_base64(emb)
|
||||||
|
else:
|
||||||
|
return emb
|
||||||
|
data = [{"object": "embedding", "embedding": enc_emb(emb), "index": n} for n, emb in enumerate(embeddings)]
|
||||||
|
|
||||||
response = json.dumps({
|
response = json.dumps({
|
||||||
"object": "list",
|
"object": "list",
|
||||||
|
Loading…
Reference in New Issue
Block a user