From 0d36c18f5d9265eeba07139d3fdfc859e2fa680c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 11 May 2023 17:07:20 -0300
Subject: [PATCH] Always return only the new tokens in generation functions

---
 extensions/api/blocking_api.py  |  2 +-
 extensions/api/streaming_api.py |  2 +-
 extensions/openai/script.py     | 12 +++---------
 modules/text_generation.py      | 25 +++++++++++--------------
 4 files changed, 16 insertions(+), 25 deletions(-)

diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py
index 114f6048..134e99d4 100644
--- a/extensions/api/blocking_api.py
+++ b/extensions/api/blocking_api.py
@@ -43,7 +43,7 @@ class Handler(BaseHTTPRequestHandler):
 
             response = json.dumps({
                 'results': [{
-                    'text': answer[len(prompt):]
+                    'text': answer
                 }]
             })
             self.wfile.write(response.encode('utf-8'))
diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py
index 1fed73a9..e50dfa22 100644
--- a/extensions/api/streaming_api.py
+++ b/extensions/api/streaming_api.py
@@ -29,7 +29,7 @@ async def _handle_connection(websocket, path):
             prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
 
         # As we stream, only send the new bytes.
-        skip_index = len(prompt)
+        skip_index = 0
         message_num = 0
 
         for a in generator:
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 8685ff0b..90f7c273 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -340,17 +340,14 @@ class Handler(BaseHTTPRequestHandler):
         # generate reply #######################################
         if debug:
             print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
-        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=True)
+        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
 
         answer = ''
         seen_content = ''
         longest_stop_len = max([len(x) for x in stopping_strings])
 
         for a in generator:
-            if isinstance(a, str):
-                answer = a
-            else:
-                answer = a[0]
+            answer = a
 
             stop_string_found = False
             len_seen = len(seen_content)
@@ -521,10 +518,7 @@ class Handler(BaseHTTPRequestHandler):
 
         answer = ''
         for a in generator:
-            if isinstance(a, str):
-                answer = a
-            else:
-                answer = a[0]
+            answer = a
 
         completion_token_count = len(encode(answer)[0])
 
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 0bba7129..2dabc2ba 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -104,18 +104,15 @@ def fix_galactica(s):
 def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
     if shared.model_type == 'HF_seq2seq':
         reply = decode(output_ids, state['skip_special_tokens'])
-        if not is_chat:
-            reply = apply_extensions('output', reply)
     else:
         new_tokens = len(output_ids) - len(input_ids[0])
         reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
-
         if type(shared.tokenizer) is transformers.LlamaTokenizer:
             if len(original_question) > 0 and original_question[-1] not in [' ', '\n']:
                 reply = ' ' + reply
 
-        if not is_chat:
-            reply = original_question + apply_extensions('output', reply)
+    if not is_chat:
+        reply = apply_extensions('output', reply)
 
     return reply
 
@@ -149,6 +146,9 @@ def stop_everything_event():
 
 def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None):
     for reply in generate_reply(question, state, eos_token, stopping_strings, is_chat=False):
+        if shared.model_type not in ['HF_seq2seq']:
+            reply = question + reply
+
         yield formatted_outputs(reply, shared.model_name)
 
 
@@ -236,7 +236,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
     t0 = time.time()
     try:
         if not is_chat and shared.model_type != 'HF_seq2seq':
-            yield original_question
+            yield ''
 
         # Generate the entire reply at once.
         if not state['stream']:
@@ -291,21 +291,18 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
     t0 = time.time()
     try:
         if not is_chat:
-            yield question
+            yield ''
 
         if not state['stream']:
             reply = shared.model.generate(context=question, **generate_params)
-            output = original_question + reply
             if not is_chat:
-                reply = original_question + apply_extensions('output', reply)
+                reply = apply_extensions('output', reply)
 
             yield reply
         else:
-
             for reply in shared.model.generate_with_streaming(context=question, **generate_params):
-                output = original_question + reply
                 if not is_chat:
-                    reply = original_question + apply_extensions('output', reply)
+                    reply = apply_extensions('output', reply)
 
                 yield reply
 
@@ -314,7 +311,7 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
     finally:
         t1 = time.time()
         original_tokens = len(encode(original_question)[0])
-        new_tokens = len(encode(output)[0]) - original_tokens
+        new_tokens = len(encode(original_question + reply)[0]) - original_tokens
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
 
@@ -349,7 +346,7 @@ def generate_reply_flexgen(question, original_question, seed, state, eos_token=N
     t0 = time.time()
    try:
         if not is_chat:
-            yield question
+            yield ''
 
         # Generate the entire reply at once.
         if not state['stream']:
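
Note: the core pattern this patch standardizes on is visible in get_reply_from_output_ids() above: decode only the tokens produced past the original input length. A minimal standalone sketch of that idea with Hugging Face transformers follows; the model name and prompt are illustrative placeholders, not taken from the patch.

    # Standalone illustration of the "return only the new tokens" idea from
    # get_reply_from_output_ids() above, using Hugging Face transformers.
    # The model name and prompt are placeholders, not part of the patch.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    prompt = 'The quick brown fox'
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    output_ids = model.generate(input_ids, max_new_tokens=20)[0]

    # Decoder-only models return prompt + completion in output_ids, so slice off
    # the prompt, mirroring: new_tokens = len(output_ids) - len(input_ids[0])
    new_tokens = len(output_ids) - len(input_ids[0])
    reply = tokenizer.decode(output_ids[-new_tokens:], skip_special_tokens=True)
    print(reply)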
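
Note for API consumers: with this change, responses and streamed chunks contain only the newly generated text, and blocking_api.py no longer strips the prompt server-side (the old answer[len(prompt):]). A minimal sketch of a blocking-API call is below, assuming the api extension's usual defaults (localhost, port 5000, /api/v1/generate) and a max_new_tokens request field, none of which appear in the diff above; only the {'results': [{'text': ...}]} response shape does.

    # Sketch of a blocking-API client after this patch. Host, port, endpoint path
    # and the max_new_tokens field are assumptions (the api extension's usual
    # defaults); the response shape comes from blocking_api.py in the diff above.
    import json
    import urllib.request

    prompt = 'Write a haiku about version control.'
    payload = {'prompt': prompt, 'max_new_tokens': 100}

    request = urllib.request.Request(
        'http://localhost:5000/api/v1/generate',  # assumed default endpoint
        data=json.dumps(payload).encode('utf-8'),
        headers={'Content-Type': 'application/json'},
    )

    with urllib.request.urlopen(request) as response:
        result = json.loads(response.read().decode('utf-8'))

    # 'text' now holds only the new tokens; do not slice the prompt off again.
    print(result['results'][0]['text'])

Clients that previously removed the prompt themselves (mirroring the old server-side slice) should drop that step, or they will truncate the actual completion.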