2023-07-11 23:50:08 +02:00
import os
2023-09-16 05:11:16 +02:00
2023-07-24 16:28:12 +02:00
import numpy as np
2023-09-16 05:11:16 +02:00
from extensions . openai . errors import ServiceUnavailableError
from extensions . openai . utils import debug_msg , float_list_to_base64
from sentence_transformers import SentenceTransformer
2023-07-11 23:50:08 +02:00
2023-09-18 03:39:29 +02:00
# Guard flag: set to True by initialize_embedding_params() once the module-level
# embedding settings (st_model, embeddings_model, embeddings_device) are populated.
embeddings_params_initialized = False
2023-11-06 06:38:29 +01:00
2023-09-18 03:39:29 +02:00
def initialize_embedding_params():
    """Populate the module-level embedding settings on first use.

    The ``params`` import is deferred to call time to avoid a circular import
    with ``extensions.openai.script``. The ``embeddings_params_initialized``
    flag turns every later call into a no-op.

    Environment variables take precedence over ``params``:

    * ``OPENEDAI_EMBEDDING_MODEL`` / ``params['embedding_model']``: model
      name, defaulting to ``all-mpnet-base-v2``.
    * ``OPENEDAI_EMBEDDING_DEVICE`` / ``params['embedding_device']``: device
      name, defaulting to ``cpu``. ``auto`` (any letter case) is stored as
      ``None``.

    Side effects:
        Sets the globals ``st_model``, ``embeddings_model`` (reset to
        ``None``), ``embeddings_device`` and ``embeddings_params_initialized``.
    """
    global embeddings_params_initialized
    global st_model, embeddings_model, embeddings_device

    if embeddings_params_initialized:
        return

    # Imported lazily: a top-level import here would be circular.
    from extensions.openai.script import params

    st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
    embeddings_model = None
    # OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
    embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
    if embeddings_device.lower() == 'auto':
        # None is passed through to SentenceTransformer(device=...) later,
        # leaving the device choice to the library.
        embeddings_device = None

    # Only mark as initialized once everything above succeeded, so a failed
    # first attempt is retried on the next call.
    embeddings_params_initialized = True
2023-07-11 23:50:08 +02:00
2023-09-16 05:11:16 +02:00
2023-07-24 16:28:12 +02:00
def load_embedding_model(model: str) -> SentenceTransformer:
    """Load ``model`` with SentenceTransformer and cache it module-wide.

    Args:
        model: Model name or path handed to ``SentenceTransformer``.

    Returns:
        The loaded model, which is also stored in the global
        ``embeddings_model``.

    Raises:
        ServiceUnavailableError: If anything goes wrong while loading. The
            cached ``embeddings_model`` is reset to ``None`` and the original
            exception is attached as the cause.
    """
    global embeddings_device, embeddings_model

    initialize_embedding_params()

    try:
        print(f"Try embedding model: {model} on {embeddings_device}")
        # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
        embeddings_model = SentenceTransformer(model, device=embeddings_device)
        # ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
        print(f"\nLoaded embedding model: {model} on {embeddings_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {embeddings_model.max_seq_length}")
    except Exception as e:
        # Never leave a half-initialized model cached.
        embeddings_model = None
        raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e)) from e

    # The signature promises a model; previously the function fell off the
    # end and returned None.
    return embeddings_model
2023-07-24 16:28:12 +02:00
def get_embeddings_model() -> SentenceTransformer:
    """Return the cached embedding model, loading it on first access.

    Settings are initialized first. If a model name is configured and no
    model is cached yet, ``load_embedding_model`` is called to load it.
    With no configured model name, the cached value is returned as-is
    (which may be ``None``).
    """
    initialize_embedding_params()

    # Nothing to load: either no model name is configured or one is cached.
    if not st_model or embeddings_model:
        return embeddings_model

    load_embedding_model(st_model)  # lazy load the model
    return embeddings_model
2023-07-12 20:33:25 +02:00
2023-07-24 16:28:12 +02:00
def get_embeddings_model_name() -> str:
    """Return the configured embedding model name (``st_model``).

    Settings are initialized on demand, so this is safe to call before any
    model has been loaded.
    """
    initialize_embedding_params()
    return st_model
2023-07-12 20:33:25 +02:00
2023-07-24 16:28:12 +02:00
def get_embeddings(input: list) -> np.ndarray:
    """Encode ``input`` into normalized embedding vectors.

    Args:
        input: The items to embed, passed straight to the model's
            ``encode`` method.

    Returns:
        The value returned by ``encode`` when called with
        ``convert_to_numpy=True`` and ``normalize_embeddings=True``.

    Raises:
        ServiceUnavailableError: If no embedding model is available, i.e.
            ``get_embeddings_model`` returned ``None`` because no model name
            is configured, or if loading the model failed.
    """
    model = get_embeddings_model()
    if model is None:
        # get_embeddings_model() returns None when no model name is set;
        # report that as a service error instead of an AttributeError below.
        raise ServiceUnavailableError("Error: No embedding model is configured")

    debug_msg(f"embedding model : {model}")
    embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
    debug_msg(f"embedding result : {embedding}")  # might be too long even for debug, use at your own will
    return embedding
2023-07-24 16:28:12 +02:00
2023-09-16 05:11:16 +02:00
2023-07-24 16:28:12 +02:00
def embeddings(input: list, encoding_format: str) -> dict:
    """Build an embeddings-list response for ``input``.

    Args:
        input: The items to embed; forwarded to ``get_embeddings``.
        encoding_format: ``"base64"`` encodes each vector with
            ``float_list_to_base64``. Any other value returns each vector
            as a plain list via ``.tolist()``.

    Returns:
        A dict with ``object`` ("list"), ``data`` (one entry per embedding
        holding ``object``, ``embedding`` and ``index``), ``model`` (the
        configured model name) and ``usage`` (token counts, always 0 here).
    """
    vectors = get_embeddings(input)

    as_base64 = encoding_format == "base64"
    data = [
        {
            "object": "embedding",
            "embedding": float_list_to_base64(vec) if as_base64 else vec.tolist(),
            "index": idx,
        }
        for idx, vec in enumerate(vectors)
    ]

    response = {
        "object": "list",
        "data": data,
        "model": st_model,  # return the real model
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0,
        }
    }

    # Guard the size lookup: indexing [0] on an empty result raised IndexError.
    # len() is used rather than truthiness, which is ambiguous for arrays.
    vector_size = len(vectors[0]) if len(vectors) else 0
    debug_msg(f"Embeddings return size: {vector_size}, number: {len(vectors)}")
    return response