mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-30 06:00:15 +01:00
Make /v1/embeddings functional, add request/response types
This commit is contained in:
parent
7ed2143cd6
commit
c5be3f7acb
@ -211,7 +211,7 @@ The following environment variables can be used (they take precendence over ever
|
|||||||
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
|
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
|
||||||
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
|
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
|
||||||
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
|
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
|
||||||
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
|
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | sentence-transformers/all-mpnet-base-v2 |
|
||||||
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
|
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
|
||||||
|
|
||||||
#### Persistent settings with `settings.yaml`
|
#### Persistent settings with `settings.yaml`
|
||||||
@ -220,7 +220,7 @@ You can also set the following variables in your `settings.yaml` file:
|
|||||||
|
|
||||||
```
|
```
|
||||||
openai-embedding_device: cuda
|
openai-embedding_device: cuda
|
||||||
openai-embedding_model: all-mpnet-base-v2
|
openai-embedding_model: "sentence-transformers/all-mpnet-base-v2"
|
||||||
openai-sd_webui_url: http://127.0.0.1:7861
|
openai-sd_webui_url: http://127.0.0.1:7861
|
||||||
openai-debug: 1
|
openai-debug: 1
|
||||||
```
|
```
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# preload the embedding model, useful for Docker images to prevent re-download on config change
|
# preload the embedding model, useful for Docker images to prevent re-download on config change
|
||||||
# Dockerfile:
|
# Dockerfile:
|
||||||
# ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional
|
# ENV OPENEDAI_EMBEDDING_MODEL="sentence-transformers/all-mpnet-base-v2" # Optional
|
||||||
# RUN python3 cache_embedded_model.py
|
# RUN python3 cache_embedded_model.py
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import sentence_transformers
|
import sentence_transformers
|
||||||
|
|
||||||
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
|
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
|
||||||
model = sentence_transformers.SentenceTransformer(st_model)
|
model = sentence_transformers.SentenceTransformer(st_model)
|
||||||
|
@ -3,8 +3,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from extensions.openai.errors import ServiceUnavailableError
|
from extensions.openai.errors import ServiceUnavailableError
|
||||||
from extensions.openai.utils import debug_msg, float_list_to_base64
|
from extensions.openai.utils import debug_msg, float_list_to_base64
|
||||||
from modules import shared
|
from modules.logging_colors import logger
|
||||||
from transformers import AutoModel
|
|
||||||
|
|
||||||
embeddings_params_initialized = False
|
embeddings_params_initialized = False
|
||||||
|
|
||||||
@ -16,38 +15,44 @@ def initialize_embedding_params():
|
|||||||
'''
|
'''
|
||||||
global embeddings_params_initialized
|
global embeddings_params_initialized
|
||||||
if not embeddings_params_initialized:
|
if not embeddings_params_initialized:
|
||||||
global st_model, embeddings_model, embeddings_device
|
|
||||||
from extensions.openai.script import params
|
from extensions.openai.script import params
|
||||||
|
|
||||||
|
global st_model, embeddings_model, embeddings_device
|
||||||
|
|
||||||
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
|
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
|
||||||
embeddings_model = None
|
embeddings_model = None
|
||||||
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
|
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
|
||||||
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
|
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
|
||||||
if embeddings_device.lower() == 'auto':
|
if embeddings_device.lower() == 'auto':
|
||||||
embeddings_device = None
|
embeddings_device = None
|
||||||
|
|
||||||
embeddings_params_initialized = True
|
embeddings_params_initialized = True
|
||||||
|
|
||||||
|
|
||||||
def load_embedding_model(model: str):
|
def load_embedding_model(model: str):
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
logger.error("The sentence_transformers module has not been found. Please install it manually with pip install -U sentence-transformers.")
|
||||||
|
raise ModuleNotFoundError
|
||||||
|
|
||||||
initialize_embedding_params()
|
initialize_embedding_params()
|
||||||
global embeddings_device, embeddings_model
|
global embeddings_device, embeddings_model
|
||||||
try:
|
try:
|
||||||
print(f"Try embedding model: {model} on {embeddings_device}")
|
print(f"Try embedding model: {model} on {embeddings_device}")
|
||||||
trust = shared.args.trust_remote_code
|
embeddings_model = SentenceTransformer(model, device=embeddings_device)
|
||||||
if embeddings_device == 'cpu':
|
print(f"Loaded embedding model: {model}")
|
||||||
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust).to("cpu", dtype=float)
|
|
||||||
else: #use the auto mode
|
|
||||||
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust)
|
|
||||||
print(f"\nLoaded embedding model: {model} on {embeddings_model.device}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
embeddings_model = None
|
embeddings_model = None
|
||||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
||||||
|
|
||||||
|
|
||||||
def get_embeddings_model() -> AutoModel:
|
def get_embeddings_model():
|
||||||
initialize_embedding_params()
|
initialize_embedding_params()
|
||||||
global embeddings_model, st_model
|
global embeddings_model, st_model
|
||||||
if st_model and not embeddings_model:
|
if st_model and not embeddings_model:
|
||||||
load_embedding_model(st_model) # lazy load the model
|
load_embedding_model(st_model) # lazy load the model
|
||||||
|
|
||||||
return embeddings_model
|
return embeddings_model
|
||||||
|
|
||||||
|
|
||||||
@ -66,9 +71,7 @@ def get_embeddings(input: list) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
def embeddings(input: list, encoding_format: str) -> dict:
|
def embeddings(input: list, encoding_format: str) -> dict:
|
||||||
|
|
||||||
embeddings = get_embeddings(input)
|
embeddings = get_embeddings(input)
|
||||||
|
|
||||||
if encoding_format == "base64":
|
if encoding_format == "base64":
|
||||||
data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
|
data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
|
||||||
else:
|
else:
|
||||||
@ -85,5 +88,4 @@ def embeddings(input: list, encoding_format: str) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -31,6 +31,8 @@ from .typing import (
|
|||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
DecodeRequest,
|
DecodeRequest,
|
||||||
DecodeResponse,
|
DecodeResponse,
|
||||||
|
EmbeddingsRequest,
|
||||||
|
EmbeddingsResponse,
|
||||||
EncodeRequest,
|
EncodeRequest,
|
||||||
EncodeResponse,
|
EncodeResponse,
|
||||||
LoadModelRequest,
|
LoadModelRequest,
|
||||||
@ -41,7 +43,7 @@ from .typing import (
|
|||||||
|
|
||||||
params = {
|
params = {
|
||||||
'embedding_device': 'cpu',
|
'embedding_device': 'cpu',
|
||||||
'embedding_model': 'all-mpnet-base-v2',
|
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
|
||||||
'sd_webui_url': '',
|
'sd_webui_url': '',
|
||||||
'debug': 0
|
'debug': 0
|
||||||
}
|
}
|
||||||
@ -196,19 +198,16 @@ async def handle_image_generation(request: Request):
|
|||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/embeddings")
|
@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
|
||||||
async def handle_embeddings(request: Request):
|
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
|
||||||
body = await request.json()
|
input = request_data.input
|
||||||
encoding_format = body.get("encoding_format", "")
|
|
||||||
|
|
||||||
input = body.get('input', body.get('text', ''))
|
|
||||||
if not input:
|
if not input:
|
||||||
raise HTTPException(status_code=400, detail="Missing required argument input")
|
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||||
|
|
||||||
if type(input) is str:
|
if type(input) is str:
|
||||||
input = [input]
|
input = [input]
|
||||||
|
|
||||||
response = OAIembeddings.embeddings(input, encoding_format)
|
response = OAIembeddings.embeddings(input, request_data.encoding_format)
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,6 +154,19 @@ class LoadModelRequest(BaseModel):
|
|||||||
settings: dict | None = None
|
settings: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsRequest(BaseModel):
|
||||||
|
input: str | List[str]
|
||||||
|
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
|
||||||
|
encoding_format: str = Field(default="float", description="Can be float or base64.")
|
||||||
|
user: str | None = Field(default=None, description="Unused parameter.")
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsResponse(BaseModel):
|
||||||
|
index: int
|
||||||
|
embedding: List[float]
|
||||||
|
object: str = "embedding"
|
||||||
|
|
||||||
|
|
||||||
def to_json(obj):
|
def to_json(obj):
|
||||||
return json.dumps(obj.__dict__, indent=4)
|
return json.dumps(obj.__dict__, indent=4)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user