mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Fix getting Phi-3-small-128k-instruct logits
This commit is contained in:
parent
bd7cc4234d
commit
ae86292159
@ -1,4 +1,5 @@
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
from transformers import is_torch_npu_available, is_torch_xpu_available
|
||||
@ -18,7 +19,8 @@ def get_next_logits(*args, **kwargs):
|
||||
shared.generation_lock.acquire()
|
||||
try:
|
||||
result = _get_next_logits(*args, **kwargs)
|
||||
except:
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
result = None
|
||||
|
||||
models.last_generation_time = time.time()
|
||||
@ -84,7 +86,14 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
|
||||
topk_values = [float(i) for i in topk_values]
|
||||
output = {}
|
||||
for row in list(zip(topk_values, tokens)):
|
||||
output[row[1]] = row[0]
|
||||
key = row[1]
|
||||
if isinstance(key, bytes):
|
||||
try:
|
||||
key = key.decode()
|
||||
except:
|
||||
key = key.decode('latin')
|
||||
|
||||
output[key] = row[0]
|
||||
|
||||
return output
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user