Fix logprobs tokens in OpenAI API (#5339)

lmg-anon 2024-01-22 08:07:42 -03:00 committed by GitHub
parent b5cabb6e9d
commit db1da9f98d
2 changed files with 4 additions and 4 deletions

View File

@@ -22,7 +22,7 @@ from modules.chat import (
     load_instruction_template_memoized
 )
 from modules.presets import load_preset_memoized
-from modules.text_generation import decode, encode, generate_reply
+from modules.text_generation import decode, encode, generate_reply, get_reply_from_output_ids
 
 
 class LogitsBiasProcessor(LogitsProcessor):
@@ -56,7 +56,7 @@ class LogprobProcessor(LogitsProcessor):
         if self.logprobs is not None:  # 0-5
             log_e_probabilities = F.log_softmax(logits, dim=1)
             top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
-            top_tokens = [decode(tok) for tok in top_indices[0]]
+            top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
             top_probs = [float(x) for x in top_values[0]]
             self.token_alternatives = dict(zip(top_tokens, top_probs))
             debug_msg(repr(self))
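
Context for the fix: calling decode on a single token id strips the leading space that SentencePiece-style tokenizers encode as a U+2581 prefix, so the keys of token_alternatives no longer matched the text actually returned to the client. Routing each id through get_reply_from_output_ids restores that space. A minimal sketch of the underlying tokenizer behavior, assuming a SentencePiece-based checkpoint loaded via transformers (the model name is illustrative only):

from transformers import AutoTokenizer

# Illustrative checkpoint; any SentencePiece-based (Llama-style)
# tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")

ids = tokenizer.encode("hello world", add_special_tokens=False)
last = ids[-1]

# Decoding the token on its own drops the leading space...
print(repr(tokenizer.decode([last])))                 # 'world'

# ...even though the raw token carries the SentencePiece space marker.
print(repr(tokenizer.convert_ids_to_tokens([last])))  # ['▁world']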

View File

@@ -268,8 +268,8 @@ def apply_stopping_strings(reply, all_stop_strings):
     return reply, stop_found
 
 
-def get_reply_from_output_ids(output_ids, state, starting_from=0):
-    reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
+def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
+    reply = decode(output_ids[starting_from:], state['skip_special_tokens'] if state else True)
 
     # Handle tokenizers that do not add the leading space for the first token
     if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '):
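
The hunk above ends mid-function. For context, a sketch of how the leading-space workaround plausibly continues; everything below the if-check is an assumption inferred from the visible lines, with shared.tokenizer and decode as defined in this module:

# Sketch only: the body after the if-check is inferred, not shown in
# the diff above. shared.tokenizer and decode come from this module.
def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
    reply = decode(output_ids[starting_from:], state['skip_special_tokens'] if state else True)

    # Handle tokenizers that do not add the leading space for the first token
    if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '):
        first_token = shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from]))
        if isinstance(first_token, bytes):
            first_token = first_token.decode('utf8')

        # SentencePiece encodes a leading space as U+2581 ('▁')
        if first_token.startswith('▁'):
            reply = ' ' + reply

    return reply

With state now defaulting to None, the logprobs path in the first file can call get_reply_from_output_ids([tok]) without threading the generation state into the logits processor, which is what makes that one-line change possible.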