import os

import numpy as np
from sentence_transformers import SentenceTransformer

from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.utils import debug_msg, float_list_to_base64

st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
embeddings_model = None

# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu")
if embeddings_device.lower() == 'auto':
    embeddings_device = None  # None lets sentence-transformers pick the best available device
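
# example (hypothetical values): pick a smaller model and force CUDA before starting the server:
#   export OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2
#   export OPENEDAI_EMBEDDING_DEVICE=cuda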


def load_embedding_model(model: str) -> SentenceTransformer:
    global embeddings_device, embeddings_model
    try:
        embeddings_model = 'loading...'  # sentinel so a concurrent call sees a load in progress
        # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
        emb_model = SentenceTransformer(model, device=embeddings_device)
        # note: emb_model.device may report 'cpu' even when the weights were placed on CUDA
        print(f"\nLoaded embedding model: {model} on {emb_model.device}, max sequence length: {emb_model.max_seq_length}")
    except Exception as e:
        embeddings_model = None
        raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))

    return emb_model


def get_embeddings_model() -> SentenceTransformer:
    global embeddings_model, st_model
    if st_model and not embeddings_model:
        embeddings_model = load_embedding_model(st_model)  # lazy load the model

    return embeddings_model


def get_embeddings_model_name() -> str:
    global st_model
    return st_model


def get_embeddings(input: list) -> np.ndarray:
    return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device)
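
# e.g. (hypothetical values): with the default all-mpnet-base-v2 model,
# get_embeddings(["hello", "world"]) returns a (2, 768) float32 array with unit-norm rows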


def embeddings(input: list, encoding_format: str) -> dict:
    embeddings = get_embeddings(input)

    if encoding_format == "base64":
        data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
    else:
        data = [{"object": "embedding", "embedding": emb.tolist(), "index": n} for n, emb in enumerate(embeddings)]

    response = {
        "object": "list",
        "data": data,
        "model": st_model,  # return the real model
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0,
        }
    }

    debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")

    return response
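

if __name__ == "__main__":
    # minimal smoke test, a hypothetical usage sketch (not part of the server flow):
    # the first run downloads the default model, then an OpenAI-style response is printed;
    # assumes the webui package layout (extensions.openai.*) is importable
    result = embeddings(["The quick brown fox"], encoding_format="float")
    print(result["model"], len(result["data"]), len(result["data"][0]["embedding"]))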