From 7ff2cc8316ec8d6f3a6299eaa79a795176944a35 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 17 Apr 2023 02:07:19 -0700 Subject: [PATCH] Add support for stopping strings in RWKV models --- modules/text_generation.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index 370130ed..d6f5a7aa 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -160,13 +160,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): else: if not shared.is_chat(): yield formatted_outputs(question, shared.model_name) - + evaluated_stopping_strings = ast.literal_eval(f"[{state['custom_stopping_strings']}]") # RWKV has proper streaming, which is very nice. # No need to generate 8 tokens at a time. for reply in shared.model.generate_with_streaming(context=question, **generate_params): output = original_question + reply + original_reply = reply if not shared.is_chat(): - reply = original_question + apply_extensions(reply, 'output') + original_reply = apply_extensions(reply, 'output') + reply = original_question + original_reply + found_stop_string = next((s for s in stopping_strings + evaluated_stopping_strings if s in original_reply), None) + if found_stop_string: + yield formatted_outputs(str(reply).rsplit(found_stop_string, 1)[0], shared.model_name) + break yield formatted_outputs(reply, shared.model_name) except Exception: