diff --git a/modules/logits.py b/modules/logits.py index 3e793bd0..1fe2e73e 100644 --- a/modules/logits.py +++ b/modules/logits.py @@ -1,4 +1,5 @@ import time +import traceback import torch from transformers import is_torch_npu_available, is_torch_xpu_available @@ -18,7 +19,8 @@ def get_next_logits(*args, **kwargs): shared.generation_lock.acquire() try: result = _get_next_logits(*args, **kwargs) - except: + except Exception: + traceback.print_exc() result = None models.last_generation_time = time.time() @@ -84,7 +86,14 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur topk_values = [float(i) for i in topk_values] output = {} for row in list(zip(topk_values, tokens)): - output[row[1]] = row[0] + key = row[1] + if isinstance(key, bytes): + try: + key = key.decode() + except: + key = key.decode('latin') + + output[key] = row[0] return output else: