Alternative solution to "get next logits" deadlock (#6106)

This commit is contained in:
oobabooga 2024-06-13 19:33:15 -07:00
parent 9aef01551d
commit 0f3a423de1

View File

@ -16,19 +16,24 @@ def get_next_logits(*args, **kwargs):
if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']: if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
shared.model, shared.tokenizer = load_model(shared.previous_model_name) shared.model, shared.tokenizer = load_model(shared.previous_model_name)
needs_lock = kwargs.get('use_samplers', False)
if needs_lock:
shared.generation_lock.acquire() shared.generation_lock.acquire()
try: try:
result = _get_next_logits(*args, **kwargs) result = _get_next_logits(*args, **kwargs)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
result = None result = None
if needs_lock:
models.last_generation_time = time.time() models.last_generation_time = time.time()
shared.generation_lock.release() shared.generation_lock.release()
return result return result
def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False): def _get_next_logits(prompt, state, use_samplers=False, previous="", top_logits=25, return_dict=False):
if shared.model is None: if shared.model is None:
logger.error("No model is loaded! Select one in the Model tab.") logger.error("No model is loaded! Select one in the Model tab.")
return 'Error: No model is loaded1 Select one in the Model tab.', previous return 'Error: No model is loaded1 Select one in the Model tab.', previous