From 341e13503634a0debb684105f055e09772d16c6e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 12 Mar 2023 02:53:08 -0300
Subject: [PATCH] Various fixes in chat mode

---
 modules/callbacks.py       |  1 +
 modules/chat.py            | 16 ++++++----------
 modules/text_generation.py | 29 +++++++++++++++--------------
 3 files changed, 22 insertions(+), 24 deletions(-)

diff --git a/modules/callbacks.py b/modules/callbacks.py
index e0d1c988..faa4a5e9 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -64,6 +64,7 @@ class Iteratorize:
                 ret = self.mfunc(callback=_callback, **self.kwargs)
             except ValueError:
                 pass
+            clear_torch_cache()
             self.q.put(self.sentinel)
             if self.c_callback:
                 self.c_callback(ret)
diff --git a/modules/chat.py b/modules/chat.py
index 69d81e94..f40f8299 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -115,18 +115,14 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
         visible_text = visible_text.replace('\n', '<br>')
     text = apply_extensions(text, "input")
 
+    if custom_generate_chat_prompt is None:
+        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+    else:
+        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+
     # Generate
     reply = ''
     for i in range(chat_generation_attempts):
-
-        # The prompt needs to be generated here because, as the reply
-        # grows, it may become necessary to remove more old messages to
-        # fit into the 2048 tokens window.
-        if custom_generate_chat_prompt is None:
-            prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]))
-        else:
-            prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]))
-
         for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
 
             # Extracting the reply
@@ -160,10 +156,10 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
 
+    prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
     reply = ''
     for i in range(chat_generation_attempts):
 
-        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]), impersonate=True)
         for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
             reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
             if not substring_found:
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 2460df4f..7966e126 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -92,21 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
    if shared.is_RWKV:
-        if shared.args.no_stream:
-            reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
-            yield formatted_outputs(reply, shared.model_name)
-        else:
-            yield formatted_outputs(question, shared.model_name)
-            # RWKV has proper streaming, which is very nice.
-            # No need to generate 8 tokens at a time.
-            for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+        try:
+            if shared.args.no_stream:
+                reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
                 yield formatted_outputs(reply, shared.model_name)
-
-        t1 = time.time()
-        output = encode(reply)[0]
-        input_ids = encode(question)
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
-        return
+            else:
+                yield formatted_outputs(question, shared.model_name)
+                # RWKV has proper streaming, which is very nice.
+                # No need to generate 8 tokens at a time.
+                for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+                    yield formatted_outputs(reply, shared.model_name)
+        finally:
+            t1 = time.time()
+            output = encode(reply)[0]
+            input_ids = encode(question)
+            print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
+            return
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):
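
A note on the callbacks.py hunk: clear_torch_cache() is called in the generator task after the model function returns and before the sentinel is queued, presumably so cached GPU memory is released between generations. The project defines its own helper for this; a typical implementation of the idea (a sketch, not necessarily the repo's exact code) looks like:

    import gc

    import torch

    def clear_torch_cache():
        # Drop dangling Python references first, then ask PyTorch to release
        # cached CUDA blocks so the next generation starts with a cleaner
        # allocator. empty_cache() is skipped on CPU-only setups.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()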
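
The comment removed from chatbot_wrapper() explained why the prompt used to be rebuilt inside the retry loop: as the reply grows, older chat messages have to be dropped so the prompt still fits the 2048-token context window. The patch hoists prompt generation out of the loop; the trimming itself is presumably handled inside generate_chat_prompt(), which still receives chat_prompt_size. A minimal sketch of that trimming idea, using hypothetical names (build_prompt and count_tokens are illustrative, not the repo's API):

    def build_prompt(history, user_input, max_prompt_tokens, count_tokens):
        # Walk the history from newest to oldest and keep only as many
        # exchanges as fit the token budget (for example, the context window
        # size minus max_new_tokens).
        rows = [user_input]
        for exchange in reversed(history):
            candidate = [exchange] + rows
            if count_tokens('\n'.join(candidate)) > max_prompt_tokens:
                break
            rows = candidate
        return '\n'.join(rows)

    # count_tokens could be, for example, lambda s: len(tokenizer.encode(s)).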
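
In text_generation.py, the RWKV branch is wrapped in try/finally so the timing printout runs even if streaming is interrupted partway through, for example when the consumer stops iterating over the generator. The same pattern in isolation, as an illustrative sketch with made-up names:

    import time

    def stream_with_stats(generate_tokens):
        # Wrap a streaming generator so the statistics line is printed even
        # when the caller abandons the stream early or an error occurs.
        t0 = time.time()
        emitted = 0
        try:
            for token in generate_tokens():
                emitted += 1
                yield token
        finally:
            t1 = time.time()
            print(f"Output generated in {t1 - t0:.2f} seconds "
                  f"({emitted / max(t1 - t0, 1e-6):.2f} tokens/s, {emitted} tokens)")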