Use convert_ids_to_tokens instead of decode in logits endpoint

This preserves the llama tokenizer's space markers, which decode() strips when individual token ids are converted back to text.
oobabooga 2023-11-19 09:22:08 -08:00
parent 8cf05c1b31
commit a2e6d00128


@@ -55,7 +55,10 @@ def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
     if is_non_hf_exllamav1 or is_non_hf_llamacpp:
         topk_indices = [i.expand((1, 1)) for i in topk_indices]
 
-    tokens = [shared.tokenizer.decode(i) for i in topk_indices]
+    if hasattr(shared.tokenizer, 'convert_ids_to_tokens'):
+        tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices]
+    else:
+        tokens = [shared.tokenizer.decode(i) for i in topk_indices]
 
     if return_dict:
         output = {}
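
For context, a minimal sketch (not part of the commit) of the behavior difference, assuming a SentencePiece-based llama tokenizer loaded through Hugging Face transformers; the checkpoint name and example string are illustrative only:

from transformers import AutoTokenizer

# Any llama-style (SentencePiece) tokenizer shows the same behavior;
# this checkpoint name is just an example.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# "Hello world" tokenizes to something like ["▁Hello", "▁world"];
# take the id of the second token, which carries a leading-space marker.
token_id = tokenizer.encode("Hello world", add_special_tokens=False)[1]

# decode() maps the id back to plain text and drops the space marker:
print(tokenizer.decode(token_id))                 # typically "world"

# convert_ids_to_tokens() returns the raw vocabulary token, marker intact:
print(tokenizer.convert_ids_to_tokens(token_id))  # typically "▁world"

In the patched code, the int(i) cast is needed because the top-k indices are tensors, while convert_ids_to_tokens expects plain integer ids; the hasattr check falls back to decode() for tokenizers that lack the method.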