--idle-timeout flag to unload the model if unused for N minutes (#6026)

This commit is contained in:
oobabooga 2024-05-19 23:29:39 -03:00 committed by GitHub
parent 818b4e0354
commit 9f77ed1b98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 57 additions and 13 deletions

View File

@ -308,9 +308,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
'internal': output['internal'] 'internal': output['internal']
} }
if shared.model_name == 'None' or shared.model is None:
raise ValueError("No model is loaded! Select one in the Model tab.")
# Generate the prompt # Generate the prompt
kwargs = { kwargs = {
'_continue': _continue, '_continue': _continue,
@ -355,11 +352,6 @@ def impersonate_wrapper(text, state):
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
if shared.model_name == 'None' or shared.model is None:
logger.error("No model is loaded! Select one in the Model tab.")
yield '', static_output
return
prompt = generate_chat_prompt('', state, impersonate=True) prompt = generate_chat_prompt('', state, impersonate=True)
stopping_strings = get_stopping_strings(state) stopping_strings = get_stopping_strings(state)

View File

@ -1,14 +1,32 @@
import time
import torch import torch
from transformers import is_torch_npu_available, is_torch_xpu_available from transformers import is_torch_npu_available, is_torch_xpu_available
from modules import sampler_hijack, shared from modules import models, sampler_hijack, shared
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.models import load_model
from modules.text_generation import generate_reply from modules.text_generation import generate_reply
global_scores = None global_scores = None
def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False): def get_next_logits(*args, **kwargs):
if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
shared.model, shared.tokenizer = load_model(shared.previous_model_name)
shared.generation_lock.acquire()
try:
result = _get_next_logits(*args, **kwargs)
except:
result = None
models.last_generation_time = time.time()
shared.generation_lock.release()
return result
def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
if shared.model is None: if shared.model is None:
logger.error("No model is loaded! Select one in the Model tab.") logger.error("No model is loaded! Select one in the Model tab.")
return 'Error: No model is loaded1 Select one in the Model tab.', previous return 'Error: No model is loaded1 Select one in the Model tab.', previous

View File

@ -61,6 +61,9 @@ if shared.args.deepspeed:
sampler_hijack.hijack_samplers() sampler_hijack.hijack_samplers()
last_generation_time = time.time()
def load_model(model_name, loader=None): def load_model(model_name, loader=None):
logger.info(f"Loading \"{model_name}\"") logger.info(f"Loading \"{model_name}\"")
t0 = time.time() t0 = time.time()
@ -428,6 +431,7 @@ def clear_torch_cache():
def unload_model(): def unload_model():
shared.model = shared.tokenizer = None shared.model = shared.tokenizer = None
shared.previous_model_name = shared.model_name
shared.model_name = 'None' shared.model_name = 'None'
shared.lora_names = [] shared.lora_names = []
shared.model_dirty_from_training = False shared.model_dirty_from_training = False
@ -437,3 +441,21 @@ def unload_model():
def reload_model(): def reload_model():
unload_model() unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
def unload_model_if_idle():
global last_generation_time
logger.info(f"Setting a timeout of {shared.args.idle_timeout} minutes to unload the model in case of inactivity.")
while True:
shared.generation_lock.acquire()
try:
if time.time() - last_generation_time > shared.args.idle_timeout * 60:
if shared.model is not None:
logger.info("Unloading the model for inactivity.")
unload_model()
finally:
shared.generation_lock.release()
time.sleep(60)

View File

@ -13,6 +13,7 @@ from modules.logging_colors import logger
model = None model = None
tokenizer = None tokenizer = None
model_name = 'None' model_name = 'None'
previous_model_name = 'None'
is_seq2seq = False is_seq2seq = False
model_dirty_from_training = False model_dirty_from_training = False
lora_names = [] lora_names = []
@ -84,6 +85,7 @@ group.add_argument('--settings', type=str, help='Load the default interface sett
group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
group.add_argument('--chat-buttons', action='store_true', help='Show buttons on the chat tab instead of a hover menu.') group.add_argument('--chat-buttons', action='store_true', help='Show buttons on the chat tab instead of a hover menu.')
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
# Model loader # Model loader
group = parser.add_argument_group('Model loader') group = parser.add_argument_group('Model loader')

View File

@ -16,6 +16,7 @@ from transformers import (
) )
import modules.shared as shared import modules.shared as shared
from modules import models
from modules.cache_utils import process_llamacpp_cache from modules.cache_utils import process_llamacpp_cache
from modules.callbacks import ( from modules.callbacks import (
Iteratorize, Iteratorize,
@ -27,15 +28,19 @@ from modules.grammar.grammar_utils import initialize_grammar
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
from modules.html_generator import generate_basic_html from modules.html_generator import generate_basic_html
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.models import clear_torch_cache from modules.models import clear_torch_cache, load_model
def generate_reply(*args, **kwargs): def generate_reply(*args, **kwargs):
if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
shared.model, shared.tokenizer = load_model(shared.previous_model_name)
shared.generation_lock.acquire() shared.generation_lock.acquire()
try: try:
for result in _generate_reply(*args, **kwargs): for result in _generate_reply(*args, **kwargs):
yield result yield result
finally: finally:
models.last_generation_time = time.time()
shared.generation_lock.release() shared.generation_lock.release()

View File

@ -32,7 +32,7 @@ import sys
import time import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from threading import Lock from threading import Lock, Thread
import yaml import yaml
@ -52,7 +52,7 @@ from modules import (
) )
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model from modules.models import load_model, unload_model_if_idle
from modules.models_settings import ( from modules.models_settings import (
get_fallback_settings, get_fallback_settings,
get_model_metadata, get_model_metadata,
@ -245,6 +245,11 @@ if __name__ == "__main__":
shared.generation_lock = Lock() shared.generation_lock = Lock()
if shared.args.idle_timeout > 0:
timer_thread = Thread(target=unload_model_if_idle)
timer_thread.daemon = True
timer_thread.start()
if shared.args.nowebui: if shared.args.nowebui:
# Start the API in standalone mode # Start the API in standalone mode
shared.args.extensions = [x for x in shared.args.extensions if x != 'gallery'] shared.args.extensions = [x for x in shared.args.extensions if x != 'gallery']