Add support for stopping strings in RWKV models

This commit is contained in:
v0xie 2023-04-17 02:07:19 -07:00
parent b57ffc2ec9
commit 7ff2cc8316

View File

@ -160,13 +160,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
else: else:
if not shared.is_chat(): if not shared.is_chat():
yield formatted_outputs(question, shared.model_name) yield formatted_outputs(question, shared.model_name)
evaluated_stopping_strings = ast.literal_eval(f"[{state['custom_stopping_strings']}]")
# RWKV has proper streaming, which is very nice. # RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time. # No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, **generate_params): for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply output = original_question + reply
original_reply = reply
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output') original_reply = apply_extensions(reply, 'output')
reply = original_question + original_reply
found_stop_string = next((s for s in stopping_strings + evaluated_stopping_strings if s in original_reply), None)
if found_stop_string:
yield formatted_outputs(str(reply).rsplit(found_stop_string, 1)[0], shared.model_name)
break
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
except Exception: except Exception: