Include trust remote code usage in openai api's embedder (#4513)

This commit is contained in:
MrMojoR 2023-11-08 15:25:43 +01:00 committed by GitHub
parent 6c7aad11f3
commit 1754a3761b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -3,7 +3,9 @@ import os
 import numpy as np
 from extensions.openai.errors import ServiceUnavailableError
 from extensions.openai.utils import debug_msg, float_list_to_base64
-from sentence_transformers import SentenceTransformer
+from transformers import AutoModel
+from modules import shared

 embeddings_params_initialized = False
@@ -26,21 +28,23 @@ def initialize_embedding_params():
     embeddings_params_initialized = True

-def load_embedding_model(model: str) -> SentenceTransformer:
+def load_embedding_model(model: str):
     initialize_embedding_params()
     global embeddings_device, embeddings_model
     try:
         print(f"Try embedding model: {model} on {embeddings_device}")
-        # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
-        embeddings_model = SentenceTransformer(model, device=embeddings_device)
-        # ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
-        print(f"\nLoaded embedding model: {model} on {embeddings_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {embeddings_model.max_seq_length}")
+        trust = shared.args.trust_remote_code
+        if embeddings_device == 'cpu':
+            embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust).to("cpu", dtype=float)
+        else: #use the auto mode
+            embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust)
+        print(f"\nLoaded embedding model: {model} on {embeddings_model.device}")
     except Exception as e:
         embeddings_model = None
         raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))

-def get_embeddings_model() -> SentenceTransformer:
+def get_embeddings_model() -> AutoModel:
     initialize_embedding_params()
     global embeddings_model, st_model
     if st_model and not embeddings_model: