From 0f3a423de11ce4196d840ece18d39ecba6f603fa Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:33:15 -0700 Subject: [PATCH] Alternative solution to "get next logits" deadlock (#6106) --- modules/logits.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/modules/logits.py b/modules/logits.py index 1fe2e73e..447deb24 100644 --- a/modules/logits.py +++ b/modules/logits.py @@ -16,19 +16,24 @@ def get_next_logits(*args, **kwargs): if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']: shared.model, shared.tokenizer = load_model(shared.previous_model_name) - shared.generation_lock.acquire() + needs_lock = kwargs.get('use_samplers', False) + if needs_lock: + shared.generation_lock.acquire() + try: result = _get_next_logits(*args, **kwargs) except Exception: traceback.print_exc() result = None - models.last_generation_time = time.time() - shared.generation_lock.release() + if needs_lock: + models.last_generation_time = time.time() + shared.generation_lock.release() + return result -def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False): +def _get_next_logits(prompt, state, use_samplers=False, previous="", top_logits=25, return_dict=False): if shared.model is None: logger.error("No model is loaded! Select one in the Model tab.") return 'Error: No model is loaded1 Select one in the Model tab.', previous