mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
--idle-timeout flag to unload the model if unused for N minutes (#6026)
This commit is contained in:
parent
818b4e0354
commit
9f77ed1b98
@ -308,9 +308,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||||||
'internal': output['internal']
|
'internal': output['internal']
|
||||||
}
|
}
|
||||||
|
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
|
||||||
raise ValueError("No model is loaded! Select one in the Model tab.")
|
|
||||||
|
|
||||||
# Generate the prompt
|
# Generate the prompt
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'_continue': _continue,
|
'_continue': _continue,
|
||||||
@ -355,11 +352,6 @@ def impersonate_wrapper(text, state):
|
|||||||
|
|
||||||
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
|
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
|
||||||
|
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
|
||||||
logger.error("No model is loaded! Select one in the Model tab.")
|
|
||||||
yield '', static_output
|
|
||||||
return
|
|
||||||
|
|
||||||
prompt = generate_chat_prompt('', state, impersonate=True)
|
prompt = generate_chat_prompt('', state, impersonate=True)
|
||||||
stopping_strings = get_stopping_strings(state)
|
stopping_strings = get_stopping_strings(state)
|
||||||
|
|
||||||
|
@ -1,14 +1,32 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import is_torch_npu_available, is_torch_xpu_available
|
from transformers import is_torch_npu_available, is_torch_xpu_available
|
||||||
|
|
||||||
from modules import sampler_hijack, shared
|
from modules import models, sampler_hijack, shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
from modules.models import load_model
|
||||||
from modules.text_generation import generate_reply
|
from modules.text_generation import generate_reply
|
||||||
|
|
||||||
global_scores = None
|
global_scores = None
|
||||||
|
|
||||||
|
|
||||||
def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
|
def get_next_logits(*args, **kwargs):
|
||||||
|
if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.previous_model_name)
|
||||||
|
|
||||||
|
shared.generation_lock.acquire()
|
||||||
|
try:
|
||||||
|
result = _get_next_logits(*args, **kwargs)
|
||||||
|
except:
|
||||||
|
result = None
|
||||||
|
|
||||||
|
models.last_generation_time = time.time()
|
||||||
|
shared.generation_lock.release()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
|
||||||
if shared.model is None:
|
if shared.model is None:
|
||||||
logger.error("No model is loaded! Select one in the Model tab.")
|
logger.error("No model is loaded! Select one in the Model tab.")
|
||||||
return 'Error: No model is loaded! Select one in the Model tab.', previous
|
return 'Error: No model is loaded! Select one in the Model tab.', previous
|
||||||
|
@ -61,6 +61,9 @@ if shared.args.deepspeed:
|
|||||||
sampler_hijack.hijack_samplers()
|
sampler_hijack.hijack_samplers()
|
||||||
|
|
||||||
|
|
||||||
|
last_generation_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_name, loader=None):
|
def load_model(model_name, loader=None):
|
||||||
logger.info(f"Loading \"{model_name}\"")
|
logger.info(f"Loading \"{model_name}\"")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
@ -428,6 +431,7 @@ def clear_torch_cache():
|
|||||||
|
|
||||||
def unload_model():
|
def unload_model():
|
||||||
shared.model = shared.tokenizer = None
|
shared.model = shared.tokenizer = None
|
||||||
|
shared.previous_model_name = shared.model_name
|
||||||
shared.model_name = 'None'
|
shared.model_name = 'None'
|
||||||
shared.lora_names = []
|
shared.lora_names = []
|
||||||
shared.model_dirty_from_training = False
|
shared.model_dirty_from_training = False
|
||||||
@ -437,3 +441,21 @@ def unload_model():
|
|||||||
def reload_model():
|
def reload_model():
|
||||||
unload_model()
|
unload_model()
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
|
|
||||||
|
def unload_model_if_idle():
|
||||||
|
global last_generation_time
|
||||||
|
|
||||||
|
logger.info(f"Setting a timeout of {shared.args.idle_timeout} minutes to unload the model in case of inactivity.")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
shared.generation_lock.acquire()
|
||||||
|
try:
|
||||||
|
if time.time() - last_generation_time > shared.args.idle_timeout * 60:
|
||||||
|
if shared.model is not None:
|
||||||
|
logger.info("Unloading the model for inactivity.")
|
||||||
|
unload_model()
|
||||||
|
finally:
|
||||||
|
shared.generation_lock.release()
|
||||||
|
|
||||||
|
time.sleep(60)
|
||||||
|
@ -13,6 +13,7 @@ from modules.logging_colors import logger
|
|||||||
model = None
|
model = None
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
model_name = 'None'
|
model_name = 'None'
|
||||||
|
previous_model_name = 'None'
|
||||||
is_seq2seq = False
|
is_seq2seq = False
|
||||||
model_dirty_from_training = False
|
model_dirty_from_training = False
|
||||||
lora_names = []
|
lora_names = []
|
||||||
@ -84,6 +85,7 @@ group.add_argument('--settings', type=str, help='Load the default interface sett
|
|||||||
group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
|
group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
|
||||||
group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||||
group.add_argument('--chat-buttons', action='store_true', help='Show buttons on the chat tab instead of a hover menu.')
|
group.add_argument('--chat-buttons', action='store_true', help='Show buttons on the chat tab instead of a hover menu.')
|
||||||
|
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
|
||||||
|
|
||||||
# Model loader
|
# Model loader
|
||||||
group = parser.add_argument_group('Model loader')
|
group = parser.add_argument_group('Model loader')
|
||||||
|
@ -16,6 +16,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules import models
|
||||||
from modules.cache_utils import process_llamacpp_cache
|
from modules.cache_utils import process_llamacpp_cache
|
||||||
from modules.callbacks import (
|
from modules.callbacks import (
|
||||||
Iteratorize,
|
Iteratorize,
|
||||||
@ -27,15 +28,19 @@ from modules.grammar.grammar_utils import initialize_grammar
|
|||||||
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
|
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
|
||||||
from modules.html_generator import generate_basic_html
|
from modules.html_generator import generate_basic_html
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.models import clear_torch_cache
|
from modules.models import clear_torch_cache, load_model
|
||||||
|
|
||||||
|
|
||||||
def generate_reply(*args, **kwargs):
|
def generate_reply(*args, **kwargs):
|
||||||
|
if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.previous_model_name)
|
||||||
|
|
||||||
shared.generation_lock.acquire()
|
shared.generation_lock.acquire()
|
||||||
try:
|
try:
|
||||||
for result in _generate_reply(*args, **kwargs):
|
for result in _generate_reply(*args, **kwargs):
|
||||||
yield result
|
yield result
|
||||||
finally:
|
finally:
|
||||||
|
models.last_generation_time = time.time()
|
||||||
shared.generation_lock.release()
|
shared.generation_lock.release()
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Lock
|
from threading import Lock, Thread
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ from modules import (
|
|||||||
)
|
)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model
|
from modules.models import load_model, unload_model_if_idle
|
||||||
from modules.models_settings import (
|
from modules.models_settings import (
|
||||||
get_fallback_settings,
|
get_fallback_settings,
|
||||||
get_model_metadata,
|
get_model_metadata,
|
||||||
@ -245,6 +245,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
shared.generation_lock = Lock()
|
shared.generation_lock = Lock()
|
||||||
|
|
||||||
|
if shared.args.idle_timeout > 0:
|
||||||
|
timer_thread = Thread(target=unload_model_if_idle)
|
||||||
|
timer_thread.daemon = True
|
||||||
|
timer_thread.start()
|
||||||
|
|
||||||
if shared.args.nowebui:
|
if shared.args.nowebui:
|
||||||
# Start the API in standalone mode
|
# Start the API in standalone mode
|
||||||
shared.args.extensions = [x for x in shared.args.extensions if x != 'gallery']
|
shared.args.extensions = [x for x in shared.args.extensions if x != 'gallery']
|
||||||
|
Loading…
Reference in New Issue
Block a user