From 1754a3761b9e03a575a3b9fb908905673dcfc658 Mon Sep 17 00:00:00 2001 From: MrMojoR Date: Wed, 8 Nov 2023 15:25:43 +0100 Subject: [PATCH] Include trust remote code usage in openai api's embedder (#4513) --- extensions/openai/embeddings.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/extensions/openai/embeddings.py b/extensions/openai/embeddings.py index 88ab1c30..a5b52d7b 100644 --- a/extensions/openai/embeddings.py +++ b/extensions/openai/embeddings.py @@ -3,7 +3,9 @@ import os import numpy as np from extensions.openai.errors import ServiceUnavailableError from extensions.openai.utils import debug_msg, float_list_to_base64 -from sentence_transformers import SentenceTransformer +from transformers import AutoModel + +from modules import shared embeddings_params_initialized = False @@ -26,21 +28,23 @@ def initialize_embedding_params(): embeddings_params_initialized = True -def load_embedding_model(model: str) -> SentenceTransformer: +def load_embedding_model(model: str): initialize_embedding_params() global embeddings_device, embeddings_model try: print(f"Try embedding model: {model} on {embeddings_device}") - # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer - embeddings_model = SentenceTransformer(model, device=embeddings_device) - # ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM - print(f"\nLoaded embedding model: {model} on {embeddings_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {embeddings_model.max_seq_length}") + trust = shared.args.trust_remote_code + if embeddings_device == 'cpu': + embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust).to("cpu", dtype=float) + else: #use the auto mode + embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust) + print(f"\nLoaded embedding model: {model} on {embeddings_model.device}") except Exception as e: embeddings_model = None raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e)) -def get_embeddings_model() -> SentenceTransformer: +def get_embeddings_model() -> AutoModel: initialize_embedding_params() global embeddings_model, st_model if st_model and not embeddings_model: