text-generation-webui/extensions/openai/embeddings.py

87 lines
3.5 KiB
Python
Raw Normal View History

import os
2023-09-16 05:11:16 +02:00
import numpy as np
2023-09-16 05:11:16 +02:00
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.utils import debug_msg, float_list_to_base64
from sentence_transformers import SentenceTransformer
embeddings_params_initialized = False


def initialize_embedding_params():
    '''
    Resolve the embedding model name and device, once per process.

    The settings are read lazily (on the first call instead of at import time)
    to avoid a circular import with ``extensions.openai.script``. Later calls
    return immediately.

    Module globals set on the first call:
      * ``st_model`` -- ``$OPENEDAI_EMBEDDING_MODEL``, else
        ``params['embedding_model']``, else ``'all-mpnet-base-v2'``.
      * ``embeddings_model`` -- reset to ``None``; the model itself is loaded
        later, on demand.
      * ``embeddings_device`` -- ``$OPENEDAI_EMBEDDING_DEVICE``, else
        ``params['embedding_device']``, else ``'cpu'``. The value ``'auto'``
        (any letter case) is turned into ``None``.
    '''
    global embeddings_params_initialized, st_model, embeddings_model, embeddings_device

    if embeddings_params_initialized:
        return

    # Deferred import: importing at module load time would be circular.
    from extensions.openai.script import params

    default_model = params.get('embedding_model', 'all-mpnet-base-v2')
    st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", default_model)
    embeddings_model = None

    # OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
    # ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
    default_device = params.get('embedding_device', 'cpu')
    requested_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", default_device)
    # 'auto' is represented as None so that no explicit device is passed on.
    embeddings_device = None if requested_device.lower() == 'auto' else requested_device

    embeddings_params_initialized = True
def load_embedding_model(model: str) -> SentenceTransformer:
    '''
    Load ``model`` as the active embedding model and return it.

    The instance is stored in the module global ``embeddings_model``. It is
    created on the device resolved by ``initialize_embedding_params()``; a
    device of ``None`` means no explicit device is requested.

    Args:
        model: name or path accepted by ``SentenceTransformer``.

    Returns:
        The loaded ``SentenceTransformer`` (the same object now held in
        ``embeddings_model``). Existing callers may ignore it.

    Raises:
        ServiceUnavailableError: if loading fails. ``embeddings_model`` is
            reset to ``None`` and the original exception is chained as the
            cause.
    '''
    global embeddings_model
    initialize_embedding_params()
    try:
        print(f"Try embedding model: {model} on {embeddings_device}")
        # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
        embeddings_model = SentenceTransformer(model, device=embeddings_device)
        # NOTE(review): the original author observed embeddings_model.device always
        # reporting 'cpu', even when 'cuda' was requested -- unverified.
        print(f"\nLoaded embedding model: {model} on {embeddings_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {embeddings_model.max_seq_length}")
    except Exception as e:
        embeddings_model = None
        # Chain the cause so the real load failure stays in the traceback.
        raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e)) from e
    return embeddings_model
def get_embeddings_model() -> SentenceTransformer:
    '''
    Return the active embedding model, loading it on first use.

    Returns:
        The ``SentenceTransformer`` held in ``embeddings_model``. Returns
        ``None`` when no model name is configured (``st_model`` is empty).

    Raises:
        ServiceUnavailableError: propagated from ``load_embedding_model`` if the
            configured model cannot be loaded.
    '''
    initialize_embedding_params()
    # Compare with None explicitly. Plain truthiness of the model object would
    # depend on its __len__ (module count), not on whether it has been loaded.
    if st_model and embeddings_model is None:
        load_embedding_model(st_model)  # lazy load the model
    return embeddings_model
def get_embeddings_model_name() -> str:
    '''
    Return the configured embedding model name.

    The settings are resolved first, so this is safe to call before any model
    has been loaded.
    '''
    initialize_embedding_params()
    configured_name = st_model  # module global populated by initialize_embedding_params()
    return configured_name
def get_embeddings(input: list) -> np.ndarray:
    '''
    Encode ``input`` with the active embedding model.

    Args:
        input: the texts to embed; passed unchanged to ``model.encode``.

    Returns:
        A NumPy array of normalized embeddings (``convert_to_numpy=True``,
        ``normalize_embeddings=True``). ``embeddings()`` treats it as one row
        per input text.

    Raises:
        ServiceUnavailableError: if no embedding model is configured, or if the
            configured model fails to load.
    '''
    model = get_embeddings_model()
    if model is None:
        # An empty model name leaves no model to use. Without this guard the
        # call below would fail with an opaque AttributeError on None.encode.
        raise ServiceUnavailableError("Error: No embedding model configured", internal_message="embedding model name is empty")
    debug_msg(f"embedding model : {model}")
    embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
    debug_msg(f"embedding result : {embedding}")  # might be too long even for debug, use at your own risk
    return embedding
def embeddings(input: list, encoding_format: str) -> dict:
    '''
    Build an OpenAI-style embeddings response for ``input``.

    Args:
        input: texts to embed.
        encoding_format: ``"base64"`` encodes each vector with
            ``float_list_to_base64``. Any other value returns plain float lists.

    Returns:
        ``{"object": "list", "data": [...], "model": <model name>, "usage": {...}}``.
        Each ``data`` item is
        ``{"object": "embedding", "embedding": <vector>, "index": <position>}``.
        Token usage is always reported as 0.
    '''
    # Named ``vectors`` so the local does not shadow this function's own name.
    vectors = get_embeddings(input)

    if encoding_format == "base64":
        data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(vectors)]
    else:
        data = [{"object": "embedding", "embedding": emb.tolist(), "index": n} for n, emb in enumerate(vectors)]

    response = {
        "object": "list",
        "data": data,
        "model": get_embeddings_model_name(),  # return the real model
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0,
        }
    }

    # An empty ``input`` produces no rows. Indexing row 0 would then raise
    # IndexError after a successful encode, so report a size of 0 instead.
    vector_size = len(vectors[0]) if len(vectors) > 0 else 0
    debug_msg(f"Embeddings return size: {vector_size}, number: {len(vectors)}")
    return response