mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
[extensions/openai] Support undocumented base64 'encoding_format' param for compatibility with official OpenAI client (#1876)
This commit is contained in:
parent
d78b04f0b4
commit
791a38bad1
@ -1,4 +1,6 @@
|
||||
import base64
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
@ -45,6 +47,20 @@ def clamp(value, minvalue, maxvalue):
|
||||
return max(minvalue, min(value, maxvalue))
|
||||
|
||||
|
||||
def float_list_to_base64(float_list):
|
||||
# Convert the list to a float32 array that the OpenAPI client expects
|
||||
float_array = np.array(float_list, dtype="float32")
|
||||
|
||||
# Get raw bytes
|
||||
bytes_array = float_array.tobytes()
|
||||
|
||||
# Encode bytes into base64
|
||||
encoded_bytes = base64.b64encode(bytes_array)
|
||||
|
||||
# Turn raw base64 encoded bytes into ASCII
|
||||
ascii_string = encoded_bytes.decode('ascii')
|
||||
return ascii_string
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path.startswith('/v1/models'):
|
||||
@ -435,7 +451,13 @@ class Handler(BaseHTTPRequestHandler):
|
||||
|
||||
embeddings = embedding_model.encode(input).tolist()
|
||||
|
||||
data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
|
||||
def enc_emb(emb):
|
||||
# If base64 is specified, encode. Otherwise, do nothing.
|
||||
if body.get("encoding_format", "") == "base64":
|
||||
return float_list_to_base64(emb)
|
||||
else:
|
||||
return emb
|
||||
data = [{"object": "embedding", "embedding": enc_emb(emb), "index": n} for n, emb in enumerate(embeddings)]
|
||||
|
||||
response = json.dumps({
|
||||
"object": "list",
|
||||
|
Loading…
Reference in New Issue
Block a user