text-generation-webui/extensions/openai/models.py

84 lines
2.3 KiB
Python
Raw Normal View History

2023-09-11 23:49:30 +02:00
from modules import shared
from modules.logging_colors import logger
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.utils import get_available_loras, get_available_models
2023-07-12 20:33:25 +02:00
2023-11-08 03:59:02 +01:00
def get_current_model_info():
    """Report the model, LoRAs and loader currently recorded in ``shared``.

    Returns:
        dict: ``model_name`` and ``lora_names`` as stored on ``shared``, plus
        ``loader`` as stored on ``shared.args``.
    """
    info = dict(
        model_name=shared.model_name,
        lora_names=shared.lora_names,
        loader=shared.args.loader,
    )
    return info
2023-11-08 04:59:27 +01:00
def list_models():
    """Return the available model names under the ``model_names`` key.

    Returns:
        dict: ``{'model_names': [...]}`` holding everything
        ``get_available_models()`` reports except its first entry.
    """
    available = get_available_models()
    # The first entry is skipped; presumably it is a placeholder rather than
    # a real model -- confirm against modules.utils.get_available_models.
    return dict(model_names=available[1:])
def list_dummy_models():
    """Build a model listing that starts with well-known placeholder ids.

    Returns:
        dict: ``{"object": "list", "data": [...]}`` where ``data`` holds one
        ``model_info_dict`` entry for each placeholder id, followed by one for
        each available model (the first entry of ``get_available_models()``
        is skipped).
    """
    # Query the available models before building any entries, as the
    # original implementation did.
    local_models = get_available_models()[1:]

    # these are expected by so much, so include some here as a dummy
    placeholder_ids = ['gpt-3.5-turbo', 'text-embedding-ada-002']

    entries = [model_info_dict(name) for name in placeholder_ids]
    entries.extend(model_info_dict(name) for name in local_models)

    return {"object": "list", "data": entries}
2023-11-08 04:59:27 +01:00
def model_info_dict(model_name: str) -> dict:
    """Describe a single model as a listing entry.

    Args:
        model_name: Identifier to place in the entry's ``id`` field.

    Returns:
        A dict with ``id`` set to ``model_name`` and the fixed fields
        ``object="model"``, ``created=0`` and ``owned_by="user"``.
    """
    entry = dict(
        id=model_name,
        object="model",
        created=0,
        owned_by="user",
    )
    return entry
2023-11-08 04:59:27 +01:00
def _load_model(data):
    """Swap the active model for ``data["model_name"]`` and apply overrides.

    Args:
        data: Mapping with a required ``"model_name"`` key and optional
            ``"args"`` / ``"settings"`` mappings. Entries of ``"args"`` are
            copied onto ``shared.args`` only for attributes that already
            exist there; entries of ``"settings"`` are copied into
            ``shared.settings`` only for keys that already exist there.

    Raises:
        KeyError: If ``data`` has no ``"model_name"`` key.
    """
    model_name = data["model_name"]
    arg_overrides = data.get("args")
    setting_overrides = data.get("settings")

    unload_model()
    update_model_parameters(get_model_metadata(model_name))

    # Update shared.args with custom model loading settings. This happens
    # after the metadata-derived parameters are applied and before the load.
    for name in (arg_overrides or {}):
        if hasattr(shared.args, name):
            setattr(shared.args, name, arg_overrides[name])

    shared.model, shared.tokenizer = load_model(model_name)

    # Update shared.settings with custom generation defaults (applied only
    # after the model has been loaded).
    for key in (setting_overrides or {}):
        if key not in shared.settings:
            continue  # unknown keys are ignored rather than added

        shared.settings[key] = setting_overrides[key]
        if key == 'truncation_length':
            logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
        elif key == 'instruction_template':
            logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
def list_loras():
    """Return the available LoRA names under the ``lora_names`` key.

    Returns:
        dict: ``{'lora_names': [...]}`` holding everything
        ``get_available_loras()`` reports except its first entry.
    """
    available = get_available_loras()
    # The first entry is skipped; presumably it is a placeholder rather than
    # a real LoRA -- confirm against modules.utils.get_available_loras.
    return dict(lora_names=available[1:])
def load_loras(lora_names):
    """Hand ``lora_names`` to ``add_lora_to_model``.

    Args:
        lora_names: Passed through unchanged; presumably a list of LoRA
            names -- confirm against modules.LoRA.add_lora_to_model.
    """
    add_lora_to_model(lora_names)
def unload_all_loras():
    """Call ``add_lora_to_model`` with an empty list of LoRA names."""
    no_loras = []
    add_lora_to_model(no_loras)