[Fix] fix openai embedding_model loading as str (#4147)

This commit is contained in:
俞航 2023-11-06 07:42:45 +08:00 committed by GitHub
parent e18a0460d4
commit 84d957ba62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -26,23 +26,21 @@ def load_embedding_model(model: str) -> SentenceTransformer:
initialize_embedding_params() initialize_embedding_params()
global embeddings_device, embeddings_model global embeddings_device, embeddings_model
try: try:
embeddings_model = 'loading...' # flag print(f"\Try embedding model: {model} on {embeddings_device}")
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
emb_model = SentenceTransformer(model, device=embeddings_device) embeddings_model = SentenceTransformer(model, device=embeddings_device)
# ... emb_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM # ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
print(f"\nLoaded embedding model: {model} on {emb_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {emb_model.max_seq_length}") print(f"\nLoaded embedding model: {model} on {embeddings_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {embeddings_model.max_seq_length}")
except Exception as e: except Exception as e:
embeddings_model = None embeddings_model = None
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e)) raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
return emb_model
def get_embeddings_model() -> SentenceTransformer: def get_embeddings_model() -> SentenceTransformer:
initialize_embedding_params() initialize_embedding_params()
global embeddings_model, st_model global embeddings_model, st_model
if st_model and not embeddings_model: if st_model and not embeddings_model:
embeddings_model = load_embedding_model(st_model) # lazy load the model load_embedding_model(st_model) # lazy load the model
return embeddings_model return embeddings_model
@ -53,7 +51,11 @@ def get_embeddings_model_name() -> str:
def get_embeddings(input: list) -> np.ndarray: def get_embeddings(input: list) -> np.ndarray:
return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device) model = get_embeddings_model()
debug_msg(f"embedding model : {model}")
embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will
return embedding
def embeddings(input: list, encoding_format: str) -> dict: def embeddings(input: list, encoding_format: str) -> dict: