Add an API endpoint to reload the last-used model

This commit is contained in:
anon-contributor-0 2024-02-15 20:26:11 -05:00
parent b19d239a60
commit b9352edf12
2 changed files with 21 additions and 2 deletions

View File

@ -26,7 +26,7 @@ from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared from extensions.openai.utils import _start_cloudflared
from modules import shared from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.models import unload_model from modules.models import unload_model, load_last_model
from modules.text_generation import stop_everything_event from modules.text_generation import stop_everything_event
from .typing import ( from .typing import (
@ -325,6 +325,21 @@ async def handle_load_model(request_data: LoadModelRequest):
return HTTPException(status_code=400, detail="Failed to load the model.") return HTTPException(status_code=400, detail="Failed to load the model.")
@app.post("/v1/internal/model/loadlast", dependencies=check_admin_key)
async def handle_load_last_model():
'''
This endpoint is experimental and may change in the future.
Loads the last model used before it was unloaded.
'''
try:
load_last_model()
return JSONResponse(content="OK")
except:
traceback.print_exc()
return HTTPException(status_code=400, detail="Failed to load the last-used model.")
@app.post("/v1/internal/model/unload", dependencies=check_admin_key) @app.post("/v1/internal/model/unload", dependencies=check_admin_key)
async def handle_unload_model(): async def handle_unload_model():
unload_model() unload_model()

View File

@ -396,9 +396,13 @@ def unload_model():
clear_torch_cache() clear_torch_cache()
def load_last_model():
shared.model, shared.tokenizer = load_model(shared.previous_model_name)
def reload_model(): def reload_model():
unload_model() unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.previous_model_name)
def unload_model_if_idle(): def unload_model_if_idle():