diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index c5ef36e6..b1bb0fff 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -22,7 +22,7 @@ from modules.chat import ( load_instruction_template_memoized ) from modules.presets import load_preset_memoized -from modules.text_generation import decode, encode, generate_reply +from modules.text_generation import decode, encode, generate_reply, get_reply_from_output_ids class LogitsBiasProcessor(LogitsProcessor): @@ -56,7 +56,7 @@ class LogprobProcessor(LogitsProcessor): if self.logprobs is not None: # 0-5 log_e_probabilities = F.log_softmax(logits, dim=1) top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1) - top_tokens = [decode(tok) for tok in top_indices[0]] + top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]] top_probs = [float(x) for x in top_values[0]] self.token_alternatives = dict(zip(top_tokens, top_probs)) debug_msg(repr(self)) diff --git a/modules/text_generation.py b/modules/text_generation.py index c8562450..f4849840 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -268,8 +268,8 @@ def apply_stopping_strings(reply, all_stop_strings): return reply, stop_found -def get_reply_from_output_ids(output_ids, state, starting_from=0): - reply = decode(output_ids[starting_from:], state['skip_special_tokens']) +def get_reply_from_output_ids(output_ids, state=None, starting_from=0): + reply = decode(output_ids[starting_from:], state['skip_special_tokens'] if state else True) # Handle tokenizers that do not add the leading space for the first token if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '):