mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-26 06:10:39 +01:00
78 lines
2.4 KiB
Python
78 lines
2.4 KiB
Python
|
from modules import shared
|
||
|
from modules.utils import get_available_models
|
||
|
from modules.models import load_model, unload_model
|
||
|
from modules.models_settings import (get_model_settings_from_yamls,
|
||
|
update_model_parameters)
|
||
|
|
||
|
from extensions.openai.embeddings import get_embeddings_model_name
|
||
|
from extensions.openai.errors import *
|
||
|
|
||
|
def get_current_model_list() -> list:
    """Return a one-element list holding the currently selected model name."""
    # This is the real chat/completions model; the name may be the string "None".
    current_name = shared.model_name
    return [current_name]
|
||
|
|
||
|
def get_pseudo_model_list() -> list:
    """Return the placeholder model names that are always advertised.

    A lot of client code expects these ids to exist, so they are included
    as dummies. A new list is built on every call.
    """
    dummy_names = []
    dummy_names.append('gpt-3.5-turbo')
    dummy_names.append('text-embedding-ada-002')
    return dummy_names
|
||
|
|
||
|
def load_model(model_name: str) -> dict:
    """Make ``model_name`` the active model, loading it when necessary.

    Names that are pseudo models, the embeddings model, or the model that is
    already current are treated as available and nothing is loaded. Any other
    name is considered a real model: the current model is unloaded, the
    per-model settings from the YAML files are applied, and the backend
    loader is invoked.

    Args:
        model_name: Identifier of the model requested by the client.

    Returns:
        A legacy "engine" description dict for ``model_name``.

    Raises:
        OpenAIError: If the backend loader does not return a usable model.
            ``shared.model_name`` is reset to ``"None"`` before raising.
    """
    # This function shares its name with the backend loader imported at module
    # level, and the ``def`` rebinds that module-level name. Import the backend
    # under a distinct alias so the call below cannot recurse into this wrapper.
    from modules.models import load_model as _backend_load_model

    resp = {
        "id": model_name,
        "object": "engine",
        "owner": "self",
        "ready": True,
    }

    already_available = get_pseudo_model_list() + [get_embeddings_model_name()] + get_current_model_list()
    if model_name not in already_available:  # Real model only
        # No args. Maybe it works anyways!
        # TODO: hack some heuristics into args for better results

        shared.model_name = model_name
        unload_model()

        model_settings = get_model_settings_from_yamls(shared.model_name)
        shared.settings.update(model_settings)
        update_model_parameters(model_settings, initial=True)

        if shared.settings['mode'] != 'instruct':
            shared.settings['instruction_template'] = None

        shared.model, shared.tokenizer = _backend_load_model(shared.model_name)

        if not shared.model:  # load failed.
            shared.model_name = "None"
            # Report the name that was requested; shared.model_name has
            # already been reset and would always print "None".
            raise OpenAIError(f"Model load failed for: {model_name}")

    return resp
|
||
|
|
||
|
|
||
|
def list_models(is_legacy: bool = False) -> dict:
    """Build the model listing response.

    The listing contains, in order: the current model, the embeddings model,
    the pseudo models, and every available model.

    Args:
        is_legacy: When true, emit legacy "engine" entries; the first entry
            (the current model) is marked not ready if no model is loaded.
            Otherwise emit "model" entries.

    Returns:
        A dict with ``"object": "list"`` and the entries under ``"data"``.
    """
    # TODO: Lora's?
    model_ids = get_current_model_list() + [get_embeddings_model_name()]
    model_ids = model_ids + get_pseudo_model_list()
    model_ids = model_ids + get_available_models()

    if not is_legacy:
        entries = [
            {"id": model_id, "object": "model", "owned_by": "user", "permission": []}
            for model_id in model_ids
        ]
    else:
        entries = [
            {"id": model_id, "object": "engine", "owner": "user", "ready": True}
            for model_id in model_ids
        ]
        # The first entry is always the current model.
        if not shared.model:
            entries[0]['ready'] = False

    return {
        "object": "list",
        "data": entries,
    }
|
||
|
|
||
|
|
||
|
def model_info(model_name: str) -> dict:
    """Return a "model" description dict for ``model_name``.

    The name is echoed back as the id; no lookup or validation is performed.
    """
    info = {}
    info["id"] = model_name
    info["object"] = "model"
    info["owned_by"] = "user"
    info["permission"] = []
    return info
|
||
|
|