From 7073665a10194015c0529eca790199d41fcea9f9 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 5 Feb 2024 02:31:24 -0300
Subject: [PATCH] Truncate long chat completions inputs (#5439)

---
 modules/chat.py            | 47 +++++++++++++++++++++++++++++++++-----
 modules/text_generation.py | 13 +++++++----
 2 files changed, 50 insertions(+), 10 deletions(-)

diff --git a/modules/chat.py b/modules/chat.py
index 5380f1ac..348c5eba 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -166,18 +166,53 @@ def generate_chat_prompt(user_input, state, **kwargs):
         prompt = remove_extra_bos(prompt)
         return prompt
 
-    prompt = make_prompt(messages)
-
     # Handle truncation
     max_length = get_max_prompt_length(state)
-    while len(messages) > 0 and get_encoded_length(prompt) > max_length:
-        # Try to save the system message
-        if len(messages) > 1 and messages[0]['role'] == 'system':
+    prompt = make_prompt(messages)
+    encoded_length = get_encoded_length(prompt)
+
+    while len(messages) > 0 and encoded_length > max_length:
+
+        # Remove old message, save system message
+        if len(messages) > 2 and messages[0]['role'] == 'system':
             messages.pop(1)
-        else:
+
+        # Remove old message when no system message is present
+        elif len(messages) > 1 and messages[0]['role'] != 'system':
             messages.pop(0)
 
+        # Resort to truncating the user input
+        else:
+
+            user_message = messages[-1]['content']
+
+            # Bisect the truncation point
+            left, right = 0, len(user_message) - 1
+
+            while right - left > 1:
+                mid = (left + right) // 2
+
+                messages[-1]['content'] = user_message[mid:]
+                prompt = make_prompt(messages)
+                encoded_length = get_encoded_length(prompt)
+
+                if encoded_length <= max_length:
+                    right = mid
+                else:
+                    left = mid
+
+            messages[-1]['content'] = user_message[right:]
+            prompt = make_prompt(messages)
+            encoded_length = get_encoded_length(prompt)
+            if encoded_length > max_length:
+                logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
+                raise ValueError
+            else:
+                logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}.")
+                break
+
         prompt = make_prompt(messages)
+        encoded_length = get_encoded_length(prompt)
 
     if also_return_rows:
         return prompt, [message['content'] for message in messages]
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 198b7575..04625ab9 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -50,6 +50,11 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     else:
         generate_func = generate_reply_HF
 
+    if generate_func != generate_reply_HF and shared.args.verbose:
+        logger.info("PROMPT=")
+        print(question)
+        print()
+
     # Prepare the input
     original_question = question
     if not is_chat:
@@ -65,10 +70,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         if type(st) is list and len(st) > 0:
             all_stop_strings += st
 
-    if shared.args.verbose:
-        logger.info("PROMPT=")
-        print(question)
-
     shared.stop_everything = False
     clear_torch_cache()
     seed = set_manual_seed(state['seed'])
@@ -355,6 +356,10 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(filtered_params)
         print()
 
+        logger.info("PROMPT=")
+        print(decode(input_ids[0], skip_special_tokens=False))
+        print()
+
     t0 = time.time()
     try:
         if not is_chat and not shared.is_seq2seq:
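
Note (not part of the patch): the modules/chat.py hunk above truncates the last
user message by binary-searching for the smallest prefix to drop so that the
tokenized prompt fits within the limit. Below is a minimal standalone sketch of
that idea; count_tokens() and build_prompt() are hypothetical stand-ins for the
repository's get_encoded_length() and make_prompt():

    # Illustrative sketch only; count_tokens() and build_prompt() are
    # hypothetical stand-ins for get_encoded_length() and make_prompt().
    def truncate_last_message(messages, max_length, build_prompt, count_tokens):
        text = messages[-1]['content']
        left, right = 0, len(text) - 1

        # Bisect the cut point: right tracks a cut whose suffix fits (or the
        # initial one-character guess); move it down while larger suffixes fit.
        while right - left > 1:
            mid = (left + right) // 2
            messages[-1]['content'] = text[mid:]
            if count_tokens(build_prompt(messages)) <= max_length:
                right = mid  # fits: try keeping more of the message
            else:
                left = mid   # still too long: cut deeper

        messages[-1]['content'] = text[right:]
        return build_prompt(messages)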