mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-24 13:28:59 +01:00
Include trust remote code usage in openai api's embedder (#4513)
This commit is contained in:
parent
6c7aad11f3
commit
1754a3761b
@ -3,7 +3,9 @@ import os
|
||||
import numpy as np
|
||||
from extensions.openai.errors import ServiceUnavailableError
|
||||
from extensions.openai.utils import debug_msg, float_list_to_base64
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModel
|
||||
|
||||
from modules import shared
|
||||
|
||||
embeddings_params_initialized = False
|
||||
|
||||
@ -26,21 +28,23 @@ def initialize_embedding_params():
|
||||
embeddings_params_initialized = True
|
||||
|
||||
|
||||
def load_embedding_model(model: str) -> SentenceTransformer:
|
||||
def load_embedding_model(model: str):
|
||||
initialize_embedding_params()
|
||||
global embeddings_device, embeddings_model
|
||||
try:
|
||||
print(f"Try embedding model: {model} on {embeddings_device}")
|
||||
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
|
||||
embeddings_model = SentenceTransformer(model, device=embeddings_device)
|
||||
# ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
|
||||
print(f"\nLoaded embedding model: {model} on {embeddings_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {embeddings_model.max_seq_length}")
|
||||
trust = shared.args.trust_remote_code
|
||||
if embeddings_device == 'cpu':
|
||||
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust).to("cpu", dtype=float)
|
||||
else: #use the auto mode
|
||||
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust)
|
||||
print(f"\nLoaded embedding model: {model} on {embeddings_model.device}")
|
||||
except Exception as e:
|
||||
embeddings_model = None
|
||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
||||
|
||||
|
||||
def get_embeddings_model() -> SentenceTransformer:
|
||||
def get_embeddings_model() -> AutoModel:
|
||||
initialize_embedding_params()
|
||||
global embeddings_model, st_model
|
||||
if st_model and not embeddings_model:
|
||||
|
Loading…
Reference in New Issue
Block a user